// wave_compiler/frontend/typescript.rs

// Copyright 2026 Ojima Abraham
// SPDX-License-Identifier: Apache-2.0

//! TypeScript kernel parser producing HIR.
//!
//! Parses a subset of TypeScript suitable for GPU kernels using
//! line-based parsing. Supports function definitions with type
//! annotations, arithmetic, if/else, and array indexing.

10use crate::diagnostics::CompileError;
11use crate::hir::expr::{BinOp, Dimension, Expr, Literal};
12use crate::hir::kernel::{Kernel, KernelAttributes, KernelParam};
13use crate::hir::stmt::Stmt;
14use crate::hir::types::{AddressSpace, Type};
15
16/// Parse a TypeScript kernel source string into an HIR Kernel.
17///
18/// # Errors
19///
20/// Returns `CompileError::ParseError` if the source cannot be parsed.
21pub fn parse_typescript(source: &str) -> Result<Kernel, CompileError> {
22    let lines: Vec<&str> = source.lines().collect();
23    let mut parser = TsParser::new(&lines);
24    parser.parse_kernel()
25}
26
/// Line-oriented cursor over kernel source text.
struct TsParser<'a> {
    /// Index into `lines` of the next line to examine.
    pos: usize,
    /// The source split into lines, borrowed for the parser's lifetime.
    lines: &'a [&'a str],
}
31
32impl<'a> TsParser<'a> {
33    fn new(lines: &'a [&'a str]) -> Self {
34        Self { lines, pos: 0 }
35    }
36
37    fn parse_kernel(&mut self) -> Result<Kernel, CompileError> {
38        while self.pos < self.lines.len() {
39            let line = self.lines[self.pos].trim();
40            if line.starts_with("function ")
41                || line.contains("function ")
42                || line.starts_with("export function")
43            {
44                return self.parse_function();
45            }
46            if line.starts_with("import ")
47                || line.is_empty()
48                || line.starts_with("//")
49                || line.starts_with("kernel(")
50            {
51                self.pos += 1;
52                continue;
53            }
54            self.pos += 1;
55        }
56        Err(CompileError::ParseError {
57            message: "no kernel function found".into(),
58        })
59    }
60
61    fn parse_function(&mut self) -> Result<Kernel, CompileError> {
62        let line = self.lines[self.pos].trim().to_string();
63        let func_pos = line
64            .find("function ")
65            .ok_or_else(|| CompileError::ParseError {
66                message: "expected 'function'".into(),
67            })?;
68        let after_func = &line[func_pos + 9..];
69        let paren_start = after_func
70            .find('(')
71            .ok_or_else(|| CompileError::ParseError {
72                message: "expected '('".into(),
73            })?;
74        let name = after_func[..paren_start].trim().to_string();
75
76        let paren_end = line.find(')').ok_or_else(|| CompileError::ParseError {
77            message: "expected ')'".into(),
78        })?;
79        let params_str = &line[line.find('(').unwrap() + 1..paren_end];
80        let params = Self::parse_params(params_str);
81
82        self.pos += 1;
83
84        let body = self.parse_body()?;
85
86        Ok(Kernel {
87            name,
88            params,
89            body,
90            attributes: KernelAttributes::default(),
91        })
92    }
93
94    fn parse_params(s: &str) -> Vec<KernelParam> {
95        let mut params = Vec::new();
96        for param in s.split(',') {
97            let param = param.trim();
98            if param.is_empty() {
99                continue;
100            }
101            let parts: Vec<&str> = param.splitn(2, ':').collect();
102            let name = parts[0].trim().to_string();
103            let (ty, space) = if parts.len() > 1 {
104                let type_str = parts[1].trim();
105                if type_str.contains("[]") {
106                    (Type::Ptr(AddressSpace::Device), AddressSpace::Device)
107                } else {
108                    match type_str {
109                        "i32" => (Type::I32, AddressSpace::Private),
110                        "f32" => (Type::F32, AddressSpace::Private),
111                        "f64" => (Type::F64, AddressSpace::Private),
112                        "boolean" => (Type::Bool, AddressSpace::Private),
113                        _ => (Type::U32, AddressSpace::Private),
114                    }
115                }
116            } else {
117                (Type::U32, AddressSpace::Private)
118            };
119            params.push(KernelParam {
120                name,
121                ty,
122                address_space: space,
123            });
124        }
125        params
126    }
127
128    fn parse_body(&mut self) -> Result<Vec<Stmt>, CompileError> {
129        let mut stmts = Vec::new();
130        let mut brace_depth = 0i32;
131
132        while self.pos < self.lines.len() {
133            let line = self.lines[self.pos].trim();
134            if line == "{" {
135                brace_depth += 1;
136                self.pos += 1;
137                continue;
138            }
139            if line == "}" || line == "});" || line.starts_with("})") {
140                if brace_depth <= 0 {
141                    self.pos += 1;
142                    break;
143                }
144                brace_depth -= 1;
145                self.pos += 1;
146                continue;
147            }
148            if line.is_empty() || line.starts_with("//") {
149                self.pos += 1;
150                continue;
151            }
152
153            if line.starts_with("if ") || line.starts_with("if(") {
154                stmts.push(self.parse_if()?);
155            } else if line.starts_with("const ")
156                || line.starts_with("let ")
157                || line.starts_with("var ")
158            {
159                stmts.push(self.parse_declaration()?);
160            } else if line.contains('=') && !line.contains("==") && !line.contains("!=") {
161                stmts.push(self.parse_assignment()?);
162            } else {
163                self.pos += 1;
164            }
165        }
166        Ok(stmts)
167    }
168
169    fn parse_if(&mut self) -> Result<Stmt, CompileError> {
170        let line = self.lines[self.pos].trim();
171        let cond_start = line.find('(').unwrap_or(3);
172        let cond_end = line.rfind(')').unwrap_or(line.len());
173        let cond_str = &line[cond_start + 1..cond_end];
174        let condition = parse_ts_expr(cond_str)?;
175        self.pos += 1;
176
177        let then_body = self.parse_body()?;
178
179        let else_body =
180            if self.pos < self.lines.len() && self.lines[self.pos].trim().starts_with("else") {
181                self.pos += 1;
182                Some(self.parse_body()?)
183            } else {
184                None
185            };
186
187        Ok(Stmt::If {
188            condition,
189            then_body,
190            else_body,
191        })
192    }
193
194    fn parse_declaration(&mut self) -> Result<Stmt, CompileError> {
195        let line = self.lines[self.pos].trim().trim_end_matches(';');
196        self.pos += 1;
197
198        let clean = line
199            .trim_start_matches("const ")
200            .trim_start_matches("let ")
201            .trim_start_matches("var ");
202
203        let eq_pos = clean.find('=').ok_or_else(|| CompileError::ParseError {
204            message: format!("expected '=' in declaration: {line}"),
205        })?;
206
207        let lhs = clean[..eq_pos].trim();
208        let target = lhs.split(':').next().unwrap_or(lhs).trim().to_string();
209        let value = parse_ts_expr(clean[eq_pos + 1..].trim())?;
210
211        Ok(Stmt::Assign { target, value })
212    }
213
214    fn parse_assignment(&mut self) -> Result<Stmt, CompileError> {
215        let line = self.lines[self.pos].trim().trim_end_matches(';');
216        self.pos += 1;
217
218        let eq_pos = line.find('=').ok_or_else(|| CompileError::ParseError {
219            message: format!("expected '=' in assignment: {line}"),
220        })?;
221
222        let lhs = line[..eq_pos].trim();
223        let rhs = line[eq_pos + 1..].trim();
224        let value = parse_ts_expr(rhs)?;
225
226        if lhs.contains('[') {
227            let bracket_pos = lhs.find('[').unwrap();
228            let bracket_end = lhs.find(']').unwrap();
229            let base_name = lhs[..bracket_pos].trim();
230            let index_str = &lhs[bracket_pos + 1..bracket_end];
231            let base = Expr::Var(base_name.to_string());
232            let index = parse_ts_expr(index_str)?;
233            let offset = Expr::BinOp {
234                op: BinOp::Mul,
235                lhs: Box::new(index),
236                rhs: Box::new(Expr::Literal(Literal::Int(4))),
237            };
238            let addr = Expr::BinOp {
239                op: BinOp::Add,
240                lhs: Box::new(base),
241                rhs: Box::new(offset),
242            };
243            return Ok(Stmt::Store {
244                addr,
245                value,
246                space: AddressSpace::Device,
247            });
248        }
249
250        Ok(Stmt::Assign {
251            target: lhs.to_string(),
252            value,
253        })
254    }
255}
256
257fn parse_ts_expr(s: &str) -> Result<Expr, CompileError> {
258    let s = s.trim();
259
260    for &(op_str, op) in &[(" + ", BinOp::Add), (" - ", BinOp::Sub)] {
261        if let Some(pos) = s.rfind(op_str) {
262            return Ok(Expr::BinOp {
263                op,
264                lhs: Box::new(parse_ts_expr(&s[..pos])?),
265                rhs: Box::new(parse_ts_expr(&s[pos + op_str.len()..])?),
266            });
267        }
268    }
269
270    for &(op_str, op) in &[(" * ", BinOp::Mul), (" / ", BinOp::Div)] {
271        if let Some(pos) = s.rfind(op_str) {
272            return Ok(Expr::BinOp {
273                op,
274                lhs: Box::new(parse_ts_expr(&s[..pos])?),
275                rhs: Box::new(parse_ts_expr(&s[pos + op_str.len()..])?),
276            });
277        }
278    }
279
280    for &(op_str, op) in &[
281        (" < ", BinOp::Lt),
282        (" > ", BinOp::Gt),
283        (" === ", BinOp::Eq),
284        (" !== ", BinOp::Ne),
285    ] {
286        if let Some(pos) = s.rfind(op_str) {
287            return Ok(Expr::BinOp {
288                op,
289                lhs: Box::new(parse_ts_expr(&s[..pos])?),
290                rhs: Box::new(parse_ts_expr(&s[pos + op_str.len()..])?),
291            });
292        }
293    }
294
295    if s.starts_with('(') && s.ends_with(')') {
296        return parse_ts_expr(&s[1..s.len() - 1]);
297    }
298
299    match s {
300        "threadId()" | "thread_id()" => return Ok(Expr::ThreadId(Dimension::X)),
301        "workgroupId()" => return Ok(Expr::WorkgroupId(Dimension::X)),
302        _ => {}
303    }
304
305    if let Some("threadId" | "thread_id") = s.strip_suffix("()") {
306        return Ok(Expr::ThreadId(Dimension::X));
307    }
308
309    if let Some(bracket_pos) = s.find('[') {
310        if s.ends_with(']') {
311            return Ok(Expr::Index {
312                base: Box::new(parse_ts_expr(&s[..bracket_pos])?),
313                index: Box::new(parse_ts_expr(&s[bracket_pos + 1..s.len() - 1])?),
314            });
315        }
316    }
317
318    if let Ok(v) = s.parse::<i64>() {
319        return Ok(Expr::Literal(Literal::Int(v)));
320    }
321    if let Ok(v) = s.parse::<f64>() {
322        return Ok(Expr::Literal(Literal::Float(v)));
323    }
324
325    if s.chars().all(|c| c.is_alphanumeric() || c == '_') && !s.is_empty() {
326        return Ok(Expr::Var(s.to_string()));
327    }
328
329    Err(CompileError::ParseError {
330        message: format!("cannot parse TS expression: '{s}'"),
331    })
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Guarded element-wise add: an import line, array and scalar
    /// parameters, a declaration, an `if` and an indexed store.
    const VECTOR_ADD_SRC: &str = r#"
import { kernel, f32, threadId } from "wave";

function vectorAdd(a: f32[], b: f32[], out: f32[], n: u32) {
    const gid = threadId();
    if (gid < n) {
        out[gid] = a[gid] + b[gid];
    }
}
"#;

    /// Minimal kernel whose body is a single declaration.
    const SIMPLE_SRC: &str = r#"
function test(n: u32) {
    const x = 42;
}
"#;

    #[test]
    fn test_parse_ts_vector_add() {
        let parsed = parse_typescript(VECTOR_ADD_SRC).unwrap();
        assert_eq!(parsed.name, "vectorAdd");
        assert_eq!(parsed.params.len(), 4);
        let first = &parsed.params[0];
        assert!(first.ty.is_pointer());
    }

    #[test]
    fn test_parse_ts_simple() {
        let parsed = parse_typescript(SIMPLE_SRC).unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.body.len(), 1);
    }
}