Skip to main content

wave_compiler/hir/
validate.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! HIR validation and type checking for WAVE GPU kernels.
5//!
6//! Validates kernel definitions by checking that all variables are defined
7//! before use, types are consistent, and control flow is well-formed.
8
9use std::collections::HashMap;
10
11use super::expr::{BuiltinFunc, Expr, Literal, UnaryOp};
12use super::kernel::Kernel;
13use super::stmt::Stmt;
14use super::types::Type;
15use crate::diagnostics::error::CompileError;
16
17/// Type environment mapping variable names to their types.
18struct TypeEnv {
19    scopes: Vec<HashMap<String, Type>>,
20}
21
22impl TypeEnv {
23    fn new() -> Self {
24        Self {
25            scopes: vec![HashMap::new()],
26        }
27    }
28
29    fn define(&mut self, name: &str, ty: Type) {
30        if let Some(scope) = self.scopes.last_mut() {
31            scope.insert(name.to_string(), ty);
32        }
33    }
34
35    fn lookup(&self, name: &str) -> Option<&Type> {
36        for scope in self.scopes.iter().rev() {
37            if let Some(ty) = scope.get(name) {
38                return Some(ty);
39            }
40        }
41        None
42    }
43
44    fn push_scope(&mut self) {
45        self.scopes.push(HashMap::new());
46    }
47
48    fn pop_scope(&mut self) {
49        self.scopes.pop();
50    }
51}
52
53/// Validate and type-check a kernel definition.
54///
55/// # Errors
56///
57/// Returns `CompileError` if the kernel has type errors or undefined variables.
58pub fn validate_kernel(kernel: &Kernel) -> Result<(), CompileError> {
59    let mut env = TypeEnv::new();
60
61    for param in &kernel.params {
62        env.define(&param.name, param.ty.clone());
63    }
64
65    validate_stmts(&kernel.body, &mut env)
66}
67
68fn validate_stmts(stmts: &[Stmt], env: &mut TypeEnv) -> Result<(), CompileError> {
69    for stmt in stmts {
70        validate_stmt(stmt, env)?;
71    }
72    Ok(())
73}
74
75fn validate_stmt(stmt: &Stmt, env: &mut TypeEnv) -> Result<(), CompileError> {
76    match stmt {
77        Stmt::Assign { target, value } => {
78            let ty = infer_type(value, env)?;
79            env.define(target, ty);
80            Ok(())
81        }
82        Stmt::If {
83            condition,
84            then_body,
85            else_body,
86        } => {
87            let cond_ty = infer_type(condition, env)?;
88            if cond_ty != Type::Bool {
89                return Err(CompileError::TypeMismatch {
90                    expected: "bool".into(),
91                    found: format!("{cond_ty}"),
92                });
93            }
94            env.push_scope();
95            validate_stmts(then_body, env)?;
96            env.pop_scope();
97            if let Some(else_stmts) = else_body {
98                env.push_scope();
99                validate_stmts(else_stmts, env)?;
100                env.pop_scope();
101            }
102            Ok(())
103        }
104        Stmt::For {
105            var,
106            start,
107            end,
108            step,
109            body,
110        } => {
111            infer_type(start, env)?;
112            infer_type(end, env)?;
113            infer_type(step, env)?;
114            env.push_scope();
115            env.define(var, Type::I32);
116            validate_stmts(body, env)?;
117            env.pop_scope();
118            Ok(())
119        }
120        Stmt::While { condition, body } => {
121            let cond_ty = infer_type(condition, env)?;
122            if cond_ty != Type::Bool {
123                return Err(CompileError::TypeMismatch {
124                    expected: "bool".into(),
125                    found: format!("{cond_ty}"),
126                });
127            }
128            env.push_scope();
129            validate_stmts(body, env)?;
130            env.pop_scope();
131            Ok(())
132        }
133        Stmt::Return { value } => {
134            if let Some(val) = value {
135                infer_type(val, env)?;
136            }
137            Ok(())
138        }
139        Stmt::Store { addr, value, .. } => {
140            infer_type(addr, env)?;
141            infer_type(value, env)?;
142            Ok(())
143        }
144        Stmt::Barrier | Stmt::Fence { .. } => Ok(()),
145    }
146}
147
148/// Infer the type of an expression given a type environment.
149///
150/// # Errors
151///
152/// Returns `CompileError` if a variable is undefined or types are incompatible.
153fn infer_type(expr: &Expr, env: &TypeEnv) -> Result<Type, CompileError> {
154    match expr {
155        Expr::Var(name) => env
156            .lookup(name)
157            .cloned()
158            .ok_or_else(|| CompileError::UndefinedVariable { name: name.clone() }),
159        Expr::Literal(lit) => Ok(match lit {
160            Literal::Int(_) => Type::I32,
161            Literal::UInt(_) => Type::U32,
162            Literal::Float(_) => Type::F32,
163            Literal::Bool(_) => Type::Bool,
164        }),
165        Expr::BinOp { op, lhs, .. } => {
166            let lhs_ty = infer_type(lhs, env)?;
167            if op.is_comparison() {
168                Ok(Type::Bool)
169            } else {
170                Ok(lhs_ty)
171            }
172        }
173        Expr::UnaryOp { op, operand } => {
174            let operand_ty = infer_type(operand, env)?;
175            match op {
176                UnaryOp::Not => Ok(Type::Bool),
177                UnaryOp::Neg | UnaryOp::BitNot => Ok(operand_ty),
178            }
179        }
180        Expr::Call { func, .. } => Ok(match func {
181            BuiltinFunc::Sqrt
182            | BuiltinFunc::Sin
183            | BuiltinFunc::Cos
184            | BuiltinFunc::Exp2
185            | BuiltinFunc::Log2 => Type::F32,
186            BuiltinFunc::Abs | BuiltinFunc::Min | BuiltinFunc::Max | BuiltinFunc::AtomicAdd => {
187                Type::U32
188            }
189        }),
190        Expr::Index { base, .. } => {
191            let base_ty = infer_type(base, env)?;
192            match base_ty {
193                Type::Array(elem, _) => Ok(*elem),
194                _ => Ok(Type::F32),
195            }
196        }
197        Expr::Cast { to, .. } => Ok(to.clone()),
198        Expr::ThreadId(_)
199        | Expr::WorkgroupId(_)
200        | Expr::WorkgroupSize(_)
201        | Expr::LaneId
202        | Expr::WaveWidth
203        | Expr::Shuffle { .. } => Ok(Type::U32),
204        Expr::Load { .. } => Ok(Type::F32),
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::{BinOp, Dimension};
    use crate::hir::kernel::{KernelAttributes, KernelParam};
    use crate::hir::types::AddressSpace;

    /// Wraps `params` and `body` in a kernel named "test" with default attributes.
    fn make_kernel(params: Vec<KernelParam>, body: Vec<Stmt>) -> Kernel {
        Kernel {
            name: "test".into(),
            params,
            body,
            attributes: KernelAttributes::default(),
        }
    }

    /// Builds a reference to the variable `name`.
    fn var(name: &str) -> Expr {
        Expr::Var(name.into())
    }

    /// Builds the statement `target = value`.
    fn assign(target: &str, value: Expr) -> Stmt {
        Stmt::Assign {
            target: target.into(),
            value,
        }
    }

    #[test]
    fn test_validate_simple_kernel() {
        let n = KernelParam {
            name: "n".into(),
            ty: Type::U32,
            address_space: AddressSpace::Private,
        };
        let in_bounds = Expr::BinOp {
            op: BinOp::Lt,
            lhs: Box::new(var("gid")),
            rhs: Box::new(var("n")),
        };
        let body = vec![
            assign("gid", Expr::ThreadId(Dimension::X)),
            Stmt::If {
                condition: in_bounds,
                then_body: vec![assign("x", Expr::Literal(Literal::Int(1)))],
                else_body: None,
            },
        ];
        assert!(validate_kernel(&make_kernel(vec![n], body)).is_ok());
    }

    #[test]
    fn test_validate_undefined_variable() {
        let body = vec![assign("x", var("undefined_var"))];
        assert!(validate_kernel(&make_kernel(Vec::new(), body)).is_err());
    }

    #[test]
    fn test_infer_literal_types() {
        let env = TypeEnv::new();
        let cases = [
            (Literal::Int(42), Type::I32),
            (Literal::Float(1.0), Type::F32),
            (Literal::Bool(true), Type::Bool),
        ];
        for (literal, expected) in cases {
            assert_eq!(infer_type(&Expr::Literal(literal), &env).unwrap(), expected);
        }
    }

    #[test]
    fn test_infer_thread_id_type() {
        let env = TypeEnv::new();
        let ty = infer_type(&Expr::ThreadId(Dimension::X), &env).unwrap();
        assert_eq!(ty, Type::U32);
    }
}