1use std::collections::{HashMap, HashSet};
2
3use crate::declaration::{Declaration, Module};
4use crate::expression::{Block, Expression, ExpressionKind, Statement};
5use crate::operator::BinaryOperator;
6use crate::types::Type;
7
8pub fn validate_module(module: &Module) -> Result<(), String> {
13 let mut newtypes: HashMap<String, Type> = HashMap::new();
14 for declaration in module {
15 if let Declaration::Type(newtype) = declaration {
16 newtypes.insert(newtype.name.clone(), newtype.inner_type.clone());
17 }
18 }
19 let context = ValidationContext { newtypes };
20 for declaration in module {
21 match declaration {
22 Declaration::Function(function) => {
23 context.validate_block(&function.body)?;
24 }
25 Declaration::Constant(constant) => {
26 context.validate_expression(&constant.value)?;
27 }
28 _ => {}
29 }
30 }
31 Ok(())
32}
33
34struct ValidationContext {
35 newtypes: HashMap<String, Type>,
36}
37
38impl ValidationContext {
39 fn validate_block(&self, block: &Block) -> Result<(), String> {
40 for statement in &block.statements {
41 self.validate_statement(statement)?;
42 }
43 if let Some(result) = &block.result {
44 self.validate_expression(result)?;
45 }
46 Ok(())
47 }
48
49 fn validate_statement(&self, statement: &Statement) -> Result<(), String> {
50 match statement {
51 Statement::Expression(expression) | Statement::Return(Some(expression)) => {
52 self.validate_expression(expression)?;
53 }
54 Statement::Let { value, .. } => {
55 self.validate_expression(value)?;
56 }
57 Statement::Assign(target, value) => {
58 self.validate_expression(target)?;
59 self.validate_expression(value)?;
60 Self::check_replace_types(target, value)?;
61 }
62 Statement::Label {
63 initial_arguments, ..
64 } => {
65 for argument in initial_arguments {
66 self.validate_expression(argument)?;
67 }
68 }
69 Statement::Jump { arguments, .. } => {
70 for argument in arguments {
71 self.validate_expression(argument)?;
72 }
73 }
74 Statement::MultiReplace {
75 targets, values, ..
76 } => {
77 for target in targets {
78 self.validate_expression(target)?;
79 }
80 for value in values {
81 self.validate_expression(value)?;
82 }
83 }
84 Statement::Defer(inner) => {
85 self.validate_statement(inner)?;
86 }
87 Statement::Return(None) => {}
88 }
89 Ok(())
90 }
91
92 fn validate_expression(&self, expression: &Expression) -> Result<(), String> {
93 match &expression.kind {
94 ExpressionKind::BinaryOperation(operator, left, right) => {
95 self.validate_expression(left)?;
96 self.validate_expression(right)?;
97 self.check_binary_operands(operator, left, right)?;
98 }
99 ExpressionKind::TypeConstruction(name, fields) => {
100 for (_, value) in fields {
101 self.validate_expression(value)?;
102 }
103 self.check_construction_fields(name, fields)?;
104 }
105 ExpressionKind::Replace(target, value) | ExpressionKind::OpAssign(_, target, value) => {
106 self.validate_expression(target)?;
107 self.validate_expression(value)?;
108 Self::check_replace_types(target, value)?;
109 }
110 ExpressionKind::Call(callee, arguments) => {
111 self.validate_expression(callee)?;
112 for argument in arguments {
113 self.validate_expression(argument)?;
114 }
115 }
116 ExpressionKind::UnaryOperation(_, operand)
117 | ExpressionKind::Dereference(operand)
118 | ExpressionKind::Convert(operand, _)
119 | ExpressionKind::Transmute(operand, _) => {
120 self.validate_expression(operand)?;
121 }
122 ExpressionKind::Field(object, _) => {
123 self.validate_expression(object)?;
124 }
125 ExpressionKind::Index(object, index) => {
126 self.validate_expression(object)?;
127 self.validate_expression(index)?;
128 }
129 ExpressionKind::ArrayLiteral(elements)
130 | ExpressionKind::TupleLiteral(elements)
131 | ExpressionKind::Print(elements) => {
132 for element in elements {
133 self.validate_expression(element)?;
134 }
135 }
136 ExpressionKind::Block(block) => {
137 self.validate_block(block)?;
138 }
139 ExpressionKind::If {
140 condition,
141 then_branch,
142 else_branch,
143 } => {
144 self.validate_expression(condition)?;
145 self.validate_block(then_branch)?;
146 if let Some(else_branch) = else_branch {
147 self.validate_block(else_branch)?;
148 }
149 }
150 ExpressionKind::Match { value, arms } => {
151 self.validate_expression(value)?;
152 for arm in arms {
153 self.validate_block(&arm.body)?;
154 }
155 }
156 ExpressionKind::Slice(array, start, end) => {
157 self.validate_expression(array)?;
158 if let Some(start) = start {
159 self.validate_expression(start)?;
160 }
161 if let Some(end) = end {
162 self.validate_expression(end)?;
163 }
164 }
165 ExpressionKind::Literal(_)
166 | ExpressionKind::Variable(_)
167 | ExpressionKind::SizeOf(_) => {}
168 }
169 Ok(())
170 }
171
172 fn resolve_underlying(&self, resolved_type: &Type) -> Type {
173 match resolved_type {
174 Type::Named(name) => self.newtypes.get(name).map_or_else(
175 || resolved_type.clone(),
176 |inner| self.resolve_underlying(inner),
177 ),
178 Type::Pointer(mutability, inner) => {
179 Type::Pointer(*mutability, Box::new(self.resolve_underlying(inner)))
180 }
181 other => other.clone(),
182 }
183 }
184
185 fn check_binary_operands(
186 &self,
187 operator: &BinaryOperator,
188 left: &Expression,
189 right: &Expression,
190 ) -> Result<(), String> {
191 if matches!(operator, BinaryOperator::Logical(_)) {
192 return Ok(());
193 }
194 let (Some(left_type), Some(right_type)) = (&left.resolved_type, &right.resolved_type)
195 else {
196 return Ok(());
197 };
198 if left_type == right_type {
199 return Ok(());
200 }
201 let left_resolved = self.resolve_underlying(left_type);
202 let right_resolved = self.resolve_underlying(right_type);
203 if left_resolved != right_resolved {
204 return Err(format!(
205 "type mismatch in '{operator}': left is {left_type}, right is {right_type}",
206 ));
207 }
208 if matches!(left_type, Type::Named(_)) && matches!(right_type, Type::Named(_)) {
209 return Err(format!(
210 "cannot mix distinct types in '{operator}': left is {left_type}, right is {right_type}",
211 ));
212 }
213 Ok(())
214 }
215
216 fn check_replace_types(target: &Expression, value: &Expression) -> Result<(), String> {
217 let Some(target_resolved) = &target.resolved_type else {
218 return Ok(());
219 };
220 let Some(value_type) = &value.resolved_type else {
221 return Ok(());
222 };
223 let target_type = match target_resolved {
224 Type::Pointer(_, inner) => inner.as_ref(),
225 other => other,
226 };
227 if target_type == value_type {
228 return Ok(());
229 }
230 if matches!(target_type, Type::Named(_)) && matches!(value_type, Type::Named(_)) {
231 return Err(format!(
232 "type mismatch in assignment: target is {target_type}, value is {value_type}",
233 ));
234 }
235 Ok(())
236 }
237
238 fn check_construction_fields(
239 &self,
240 type_name: &str,
241 fields: &[(String, Expression)],
242 ) -> Result<(), String> {
243 let Some(inner) = self.newtypes.get(type_name) else {
244 return Ok(());
245 };
246 let expected_fields: Vec<&str> = match inner {
247 Type::Tuple(field_types) => field_types
248 .iter()
249 .filter_map(|field_type| match field_type {
250 Type::Named(name) => Some(name.as_str()),
251 _ => None,
252 })
253 .collect(),
254 Type::Named(name) => vec![name.as_str()],
255 _ => return Ok(()),
256 };
257
258 let mut seen = HashSet::new();
259 for (field_name, _) in fields {
260 if !expected_fields.contains(&field_name.as_str()) {
261 return Err(format!("'{type_name}' has no field '{field_name}'"));
262 }
263 if !seen.insert(field_name.as_str()) {
264 return Err(format!(
265 "duplicate field '{field_name}' in '{type_name}' construction"
266 ));
267 }
268 }
269 for expected in &expected_fields {
270 if !seen.contains(expected) {
271 return Err(format!(
272 "missing field '{expected}' in '{type_name}' construction"
273 ));
274 }
275 }
276 Ok(())
277 }
278}