1use std::collections::HashMap;
10
11use super::expr::{BuiltinFunc, Expr, Literal, UnaryOp};
12use super::kernel::Kernel;
13use super::stmt::Stmt;
14use super::types::Type;
15use crate::diagnostics::error::CompileError;
16
17struct TypeEnv {
19 scopes: Vec<HashMap<String, Type>>,
20}
21
22impl TypeEnv {
23 fn new() -> Self {
24 Self {
25 scopes: vec![HashMap::new()],
26 }
27 }
28
29 fn define(&mut self, name: &str, ty: Type) {
30 if let Some(scope) = self.scopes.last_mut() {
31 scope.insert(name.to_string(), ty);
32 }
33 }
34
35 fn lookup(&self, name: &str) -> Option<&Type> {
36 for scope in self.scopes.iter().rev() {
37 if let Some(ty) = scope.get(name) {
38 return Some(ty);
39 }
40 }
41 None
42 }
43
44 fn push_scope(&mut self) {
45 self.scopes.push(HashMap::new());
46 }
47
48 fn pop_scope(&mut self) {
49 self.scopes.pop();
50 }
51}
52
53pub fn validate_kernel(kernel: &Kernel) -> Result<(), CompileError> {
59 let mut env = TypeEnv::new();
60
61 for param in &kernel.params {
62 env.define(¶m.name, param.ty.clone());
63 }
64
65 validate_stmts(&kernel.body, &mut env)
66}
67
68fn validate_stmts(stmts: &[Stmt], env: &mut TypeEnv) -> Result<(), CompileError> {
69 for stmt in stmts {
70 validate_stmt(stmt, env)?;
71 }
72 Ok(())
73}
74
75fn validate_stmt(stmt: &Stmt, env: &mut TypeEnv) -> Result<(), CompileError> {
76 match stmt {
77 Stmt::Assign { target, value } => {
78 let ty = infer_type(value, env)?;
79 env.define(target, ty);
80 Ok(())
81 }
82 Stmt::If {
83 condition,
84 then_body,
85 else_body,
86 } => {
87 let cond_ty = infer_type(condition, env)?;
88 if cond_ty != Type::Bool {
89 return Err(CompileError::TypeMismatch {
90 expected: "bool".into(),
91 found: format!("{cond_ty}"),
92 });
93 }
94 env.push_scope();
95 validate_stmts(then_body, env)?;
96 env.pop_scope();
97 if let Some(else_stmts) = else_body {
98 env.push_scope();
99 validate_stmts(else_stmts, env)?;
100 env.pop_scope();
101 }
102 Ok(())
103 }
104 Stmt::For {
105 var,
106 start,
107 end,
108 step,
109 body,
110 } => {
111 infer_type(start, env)?;
112 infer_type(end, env)?;
113 infer_type(step, env)?;
114 env.push_scope();
115 env.define(var, Type::I32);
116 validate_stmts(body, env)?;
117 env.pop_scope();
118 Ok(())
119 }
120 Stmt::While { condition, body } => {
121 let cond_ty = infer_type(condition, env)?;
122 if cond_ty != Type::Bool {
123 return Err(CompileError::TypeMismatch {
124 expected: "bool".into(),
125 found: format!("{cond_ty}"),
126 });
127 }
128 env.push_scope();
129 validate_stmts(body, env)?;
130 env.pop_scope();
131 Ok(())
132 }
133 Stmt::Return { value } => {
134 if let Some(val) = value {
135 infer_type(val, env)?;
136 }
137 Ok(())
138 }
139 Stmt::Store { addr, value, .. } => {
140 infer_type(addr, env)?;
141 infer_type(value, env)?;
142 Ok(())
143 }
144 Stmt::Barrier | Stmt::Fence { .. } => Ok(()),
145 }
146}
147
148fn infer_type(expr: &Expr, env: &TypeEnv) -> Result<Type, CompileError> {
154 match expr {
155 Expr::Var(name) => env
156 .lookup(name)
157 .cloned()
158 .ok_or_else(|| CompileError::UndefinedVariable { name: name.clone() }),
159 Expr::Literal(lit) => Ok(match lit {
160 Literal::Int(_) => Type::I32,
161 Literal::UInt(_) => Type::U32,
162 Literal::Float(_) => Type::F32,
163 Literal::Bool(_) => Type::Bool,
164 }),
165 Expr::BinOp { op, lhs, .. } => {
166 let lhs_ty = infer_type(lhs, env)?;
167 if op.is_comparison() {
168 Ok(Type::Bool)
169 } else {
170 Ok(lhs_ty)
171 }
172 }
173 Expr::UnaryOp { op, operand } => {
174 let operand_ty = infer_type(operand, env)?;
175 match op {
176 UnaryOp::Not => Ok(Type::Bool),
177 UnaryOp::Neg | UnaryOp::BitNot => Ok(operand_ty),
178 }
179 }
180 Expr::Call { func, .. } => Ok(match func {
181 BuiltinFunc::Sqrt
182 | BuiltinFunc::Sin
183 | BuiltinFunc::Cos
184 | BuiltinFunc::Exp2
185 | BuiltinFunc::Log2 => Type::F32,
186 BuiltinFunc::Abs | BuiltinFunc::Min | BuiltinFunc::Max | BuiltinFunc::AtomicAdd => {
187 Type::U32
188 }
189 }),
190 Expr::Index { base, .. } => {
191 let base_ty = infer_type(base, env)?;
192 match base_ty {
193 Type::Array(elem, _) => Ok(*elem),
194 _ => Ok(Type::F32),
195 }
196 }
197 Expr::Cast { to, .. } => Ok(to.clone()),
198 Expr::ThreadId(_)
199 | Expr::WorkgroupId(_)
200 | Expr::WorkgroupSize(_)
201 | Expr::LaneId
202 | Expr::WaveWidth
203 | Expr::Shuffle { .. } => Ok(Type::U32),
204 Expr::Load { .. } => Ok(Type::F32),
205 }
206}
207
208#[cfg(test)]
209mod tests {
210 use super::*;
211 use crate::hir::expr::{BinOp, Dimension};
212 use crate::hir::kernel::{KernelAttributes, KernelParam};
213 use crate::hir::types::AddressSpace;
214
215 #[test]
216 fn test_validate_simple_kernel() {
217 let kernel = Kernel {
218 name: "test".into(),
219 params: vec![KernelParam {
220 name: "n".into(),
221 ty: Type::U32,
222 address_space: AddressSpace::Private,
223 }],
224 body: vec![
225 Stmt::Assign {
226 target: "gid".into(),
227 value: Expr::ThreadId(Dimension::X),
228 },
229 Stmt::If {
230 condition: Expr::BinOp {
231 op: BinOp::Lt,
232 lhs: Box::new(Expr::Var("gid".into())),
233 rhs: Box::new(Expr::Var("n".into())),
234 },
235 then_body: vec![Stmt::Assign {
236 target: "x".into(),
237 value: Expr::Literal(Literal::Int(1)),
238 }],
239 else_body: None,
240 },
241 ],
242 attributes: KernelAttributes::default(),
243 };
244 assert!(validate_kernel(&kernel).is_ok());
245 }
246
247 #[test]
248 fn test_validate_undefined_variable() {
249 let kernel = Kernel {
250 name: "test".into(),
251 params: vec![],
252 body: vec![Stmt::Assign {
253 target: "x".into(),
254 value: Expr::Var("undefined_var".into()),
255 }],
256 attributes: KernelAttributes::default(),
257 };
258 assert!(validate_kernel(&kernel).is_err());
259 }
260
261 #[test]
262 fn test_infer_literal_types() {
263 let env = TypeEnv::new();
264 assert_eq!(
265 infer_type(&Expr::Literal(Literal::Int(42)), &env).unwrap(),
266 Type::I32
267 );
268 assert_eq!(
269 infer_type(&Expr::Literal(Literal::Float(1.0)), &env).unwrap(),
270 Type::F32
271 );
272 assert_eq!(
273 infer_type(&Expr::Literal(Literal::Bool(true)), &env).unwrap(),
274 Type::Bool
275 );
276 }
277
278 #[test]
279 fn test_infer_thread_id_type() {
280 let env = TypeEnv::new();
281 assert_eq!(
282 infer_type(&Expr::ThreadId(Dimension::X), &env).unwrap(),
283 Type::U32
284 );
285 }
286}