1use crate::diagnostics::CompileError;
11use crate::hir::expr::{BinOp, Dimension, Expr, Literal};
12use crate::hir::kernel::{Kernel, KernelAttributes, KernelParam};
13use crate::hir::stmt::Stmt;
14use crate::hir::types::{AddressSpace, Type};
15
16pub fn parse_typescript(source: &str) -> Result<Kernel, CompileError> {
22 let lines: Vec<&str> = source.lines().collect();
23 let mut parser = TsParser::new(&lines);
24 parser.parse_kernel()
25}
26
/// Cursor over the lines of a TypeScript kernel source.
struct TsParser<'a> {
    /// Index into `lines` of the line currently being examined.
    pos: usize,
    /// The source, one entry per line.
    lines: &'a [&'a str],
}
31
32impl<'a> TsParser<'a> {
33 fn new(lines: &'a [&'a str]) -> Self {
34 Self { lines, pos: 0 }
35 }
36
37 fn parse_kernel(&mut self) -> Result<Kernel, CompileError> {
38 while self.pos < self.lines.len() {
39 let line = self.lines[self.pos].trim();
40 if line.starts_with("function ")
41 || line.contains("function ")
42 || line.starts_with("export function")
43 {
44 return self.parse_function();
45 }
46 if line.starts_with("import ")
47 || line.is_empty()
48 || line.starts_with("//")
49 || line.starts_with("kernel(")
50 {
51 self.pos += 1;
52 continue;
53 }
54 self.pos += 1;
55 }
56 Err(CompileError::ParseError {
57 message: "no kernel function found".into(),
58 })
59 }
60
61 fn parse_function(&mut self) -> Result<Kernel, CompileError> {
62 let line = self.lines[self.pos].trim().to_string();
63 let func_pos = line
64 .find("function ")
65 .ok_or_else(|| CompileError::ParseError {
66 message: "expected 'function'".into(),
67 })?;
68 let after_func = &line[func_pos + 9..];
69 let paren_start = after_func
70 .find('(')
71 .ok_or_else(|| CompileError::ParseError {
72 message: "expected '('".into(),
73 })?;
74 let name = after_func[..paren_start].trim().to_string();
75
76 let paren_end = line.find(')').ok_or_else(|| CompileError::ParseError {
77 message: "expected ')'".into(),
78 })?;
79 let params_str = &line[line.find('(').unwrap() + 1..paren_end];
80 let params = Self::parse_params(params_str);
81
82 self.pos += 1;
83
84 let body = self.parse_body()?;
85
86 Ok(Kernel {
87 name,
88 params,
89 body,
90 attributes: KernelAttributes::default(),
91 })
92 }
93
94 fn parse_params(s: &str) -> Vec<KernelParam> {
95 let mut params = Vec::new();
96 for param in s.split(',') {
97 let param = param.trim();
98 if param.is_empty() {
99 continue;
100 }
101 let parts: Vec<&str> = param.splitn(2, ':').collect();
102 let name = parts[0].trim().to_string();
103 let (ty, space) = if parts.len() > 1 {
104 let type_str = parts[1].trim();
105 if type_str.contains("[]") {
106 (Type::Ptr(AddressSpace::Device), AddressSpace::Device)
107 } else {
108 match type_str {
109 "i32" => (Type::I32, AddressSpace::Private),
110 "f32" => (Type::F32, AddressSpace::Private),
111 "f64" => (Type::F64, AddressSpace::Private),
112 "boolean" => (Type::Bool, AddressSpace::Private),
113 _ => (Type::U32, AddressSpace::Private),
114 }
115 }
116 } else {
117 (Type::U32, AddressSpace::Private)
118 };
119 params.push(KernelParam {
120 name,
121 ty,
122 address_space: space,
123 });
124 }
125 params
126 }
127
128 fn parse_body(&mut self) -> Result<Vec<Stmt>, CompileError> {
129 let mut stmts = Vec::new();
130 let mut brace_depth = 0i32;
131
132 while self.pos < self.lines.len() {
133 let line = self.lines[self.pos].trim();
134 if line == "{" {
135 brace_depth += 1;
136 self.pos += 1;
137 continue;
138 }
139 if line == "}" || line == "});" || line.starts_with("})") {
140 if brace_depth <= 0 {
141 self.pos += 1;
142 break;
143 }
144 brace_depth -= 1;
145 self.pos += 1;
146 continue;
147 }
148 if line.is_empty() || line.starts_with("//") {
149 self.pos += 1;
150 continue;
151 }
152
153 if line.starts_with("if ") || line.starts_with("if(") {
154 stmts.push(self.parse_if()?);
155 } else if line.starts_with("const ")
156 || line.starts_with("let ")
157 || line.starts_with("var ")
158 {
159 stmts.push(self.parse_declaration()?);
160 } else if line.contains('=') && !line.contains("==") && !line.contains("!=") {
161 stmts.push(self.parse_assignment()?);
162 } else {
163 self.pos += 1;
164 }
165 }
166 Ok(stmts)
167 }
168
169 fn parse_if(&mut self) -> Result<Stmt, CompileError> {
170 let line = self.lines[self.pos].trim();
171 let cond_start = line.find('(').unwrap_or(3);
172 let cond_end = line.rfind(')').unwrap_or(line.len());
173 let cond_str = &line[cond_start + 1..cond_end];
174 let condition = parse_ts_expr(cond_str)?;
175 self.pos += 1;
176
177 let then_body = self.parse_body()?;
178
179 let else_body =
180 if self.pos < self.lines.len() && self.lines[self.pos].trim().starts_with("else") {
181 self.pos += 1;
182 Some(self.parse_body()?)
183 } else {
184 None
185 };
186
187 Ok(Stmt::If {
188 condition,
189 then_body,
190 else_body,
191 })
192 }
193
194 fn parse_declaration(&mut self) -> Result<Stmt, CompileError> {
195 let line = self.lines[self.pos].trim().trim_end_matches(';');
196 self.pos += 1;
197
198 let clean = line
199 .trim_start_matches("const ")
200 .trim_start_matches("let ")
201 .trim_start_matches("var ");
202
203 let eq_pos = clean.find('=').ok_or_else(|| CompileError::ParseError {
204 message: format!("expected '=' in declaration: {line}"),
205 })?;
206
207 let lhs = clean[..eq_pos].trim();
208 let target = lhs.split(':').next().unwrap_or(lhs).trim().to_string();
209 let value = parse_ts_expr(clean[eq_pos + 1..].trim())?;
210
211 Ok(Stmt::Assign { target, value })
212 }
213
214 fn parse_assignment(&mut self) -> Result<Stmt, CompileError> {
215 let line = self.lines[self.pos].trim().trim_end_matches(';');
216 self.pos += 1;
217
218 let eq_pos = line.find('=').ok_or_else(|| CompileError::ParseError {
219 message: format!("expected '=' in assignment: {line}"),
220 })?;
221
222 let lhs = line[..eq_pos].trim();
223 let rhs = line[eq_pos + 1..].trim();
224 let value = parse_ts_expr(rhs)?;
225
226 if lhs.contains('[') {
227 let bracket_pos = lhs.find('[').unwrap();
228 let bracket_end = lhs.find(']').unwrap();
229 let base_name = lhs[..bracket_pos].trim();
230 let index_str = &lhs[bracket_pos + 1..bracket_end];
231 let base = Expr::Var(base_name.to_string());
232 let index = parse_ts_expr(index_str)?;
233 let offset = Expr::BinOp {
234 op: BinOp::Mul,
235 lhs: Box::new(index),
236 rhs: Box::new(Expr::Literal(Literal::Int(4))),
237 };
238 let addr = Expr::BinOp {
239 op: BinOp::Add,
240 lhs: Box::new(base),
241 rhs: Box::new(offset),
242 };
243 return Ok(Stmt::Store {
244 addr,
245 value,
246 space: AddressSpace::Device,
247 });
248 }
249
250 Ok(Stmt::Assign {
251 target: lhs.to_string(),
252 value,
253 })
254 }
255}
256
257fn parse_ts_expr(s: &str) -> Result<Expr, CompileError> {
258 let s = s.trim();
259
260 for &(op_str, op) in &[(" + ", BinOp::Add), (" - ", BinOp::Sub)] {
261 if let Some(pos) = s.rfind(op_str) {
262 return Ok(Expr::BinOp {
263 op,
264 lhs: Box::new(parse_ts_expr(&s[..pos])?),
265 rhs: Box::new(parse_ts_expr(&s[pos + op_str.len()..])?),
266 });
267 }
268 }
269
270 for &(op_str, op) in &[(" * ", BinOp::Mul), (" / ", BinOp::Div)] {
271 if let Some(pos) = s.rfind(op_str) {
272 return Ok(Expr::BinOp {
273 op,
274 lhs: Box::new(parse_ts_expr(&s[..pos])?),
275 rhs: Box::new(parse_ts_expr(&s[pos + op_str.len()..])?),
276 });
277 }
278 }
279
280 for &(op_str, op) in &[
281 (" < ", BinOp::Lt),
282 (" > ", BinOp::Gt),
283 (" === ", BinOp::Eq),
284 (" !== ", BinOp::Ne),
285 ] {
286 if let Some(pos) = s.rfind(op_str) {
287 return Ok(Expr::BinOp {
288 op,
289 lhs: Box::new(parse_ts_expr(&s[..pos])?),
290 rhs: Box::new(parse_ts_expr(&s[pos + op_str.len()..])?),
291 });
292 }
293 }
294
295 if s.starts_with('(') && s.ends_with(')') {
296 return parse_ts_expr(&s[1..s.len() - 1]);
297 }
298
299 match s {
300 "threadId()" | "thread_id()" => return Ok(Expr::ThreadId(Dimension::X)),
301 "workgroupId()" => return Ok(Expr::WorkgroupId(Dimension::X)),
302 _ => {}
303 }
304
305 if let Some("threadId" | "thread_id") = s.strip_suffix("()") {
306 return Ok(Expr::ThreadId(Dimension::X));
307 }
308
309 if let Some(bracket_pos) = s.find('[') {
310 if s.ends_with(']') {
311 return Ok(Expr::Index {
312 base: Box::new(parse_ts_expr(&s[..bracket_pos])?),
313 index: Box::new(parse_ts_expr(&s[bracket_pos + 1..s.len() - 1])?),
314 });
315 }
316 }
317
318 if let Ok(v) = s.parse::<i64>() {
319 return Ok(Expr::Literal(Literal::Int(v)));
320 }
321 if let Ok(v) = s.parse::<f64>() {
322 return Ok(Expr::Literal(Literal::Float(v)));
323 }
324
325 if s.chars().all(|c| c.is_alphanumeric() || c == '_') && !s.is_empty() {
326 return Ok(Expr::Var(s.to_string()));
327 }
328
329 Err(CompileError::ParseError {
330 message: format!("cannot parse TS expression: '{s}'"),
331 })
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// The vector-add example: the kernel name is kept, all four parameters are
    /// recognised, and the `f32[]` array parameter becomes a pointer.
    #[test]
    fn test_parse_ts_vector_add() {
        const SRC: &str = r#"
import { kernel, f32, threadId } from "wave";

function vectorAdd(a: f32[], b: f32[], out: f32[], n: u32) {
    const gid = threadId();
    if (gid < n) {
        out[gid] = a[gid] + b[gid];
    }
}
"#;
        let parsed = parse_typescript(SRC).unwrap();
        assert_eq!(parsed.name, "vectorAdd");
        assert_eq!(parsed.params.len(), 4);
        assert!(parsed.params[0].ty.is_pointer());
    }

    /// A function with a single declaration yields a one-statement body.
    #[test]
    fn test_parse_ts_simple() {
        const SRC: &str = r#"
function test(n: u32) {
    const x = 42;
}
"#;
        let parsed = parse_typescript(SRC).unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.body.len(), 1);
    }
}