1use crate::ast::{Program, Statement, WordDef};
4use crate::types::{
5 Effect, SideEffect, StackType, Type, UnionTypeInfo, VariantFieldInfo, VariantInfo,
6};
7use crate::unification::{Subst, unify_stacks};
8
9use super::{TypeChecker, format_line_prefix, validate_main_effect};
10
11impl TypeChecker {
12 pub fn check_program(&mut self, program: &Program) -> Result<(), String> {
13 for union_def in &program.unions {
15 let variants = union_def
16 .variants
17 .iter()
18 .map(|v| VariantInfo {
19 name: v.name.clone(),
20 fields: v
21 .fields
22 .iter()
23 .map(|f| VariantFieldInfo {
24 name: f.name.clone(),
25 field_type: self.parse_type_name(&f.type_name),
26 })
27 .collect(),
28 })
29 .collect();
30
31 self.unions.insert(
32 union_def.name.clone(),
33 UnionTypeInfo {
34 name: union_def.name.clone(),
35 variants,
36 },
37 );
38 }
39
40 self.validate_union_field_types(program)?;
42
43 for word in &program.words {
46 if let Some(effect) = &word.effect {
47 self.validate_effect_types(effect, &word.name)?;
51 self.env.insert(word.name.clone(), effect.clone());
52 } else {
53 return Err(format!(
54 "Word '{}' is missing a stack effect declaration.\n\
55 All words must declare their stack effect, e.g.: : {} ( -- ) ... ;",
56 word.name, word.name
57 ));
58 }
59 }
60
61 if let Some(main_effect) = self.env.get("main") {
64 validate_main_effect(main_effect)?;
65 }
66
67 for word in &program.words {
69 self.check_word(word)?;
70 }
71
72 Ok(())
73 }
74
75 pub(super) fn check_word(&self, word: &WordDef) -> Result<(), String> {
77 let line = word.source.as_ref().map(|s| s.start_line);
79 *self.current_word.borrow_mut() = Some((word.name.clone(), line));
80
81 *self.current_aux_stack.borrow_mut() = StackType::Empty;
83
84 let declared_effect = word.effect.as_ref().expect("word must have effect");
86
87 if let Some((_rest, top_type)) = declared_effect.outputs.clone().pop()
90 && matches!(top_type, Type::Quotation(_) | Type::Closure { .. })
91 {
92 *self.expected_quotation_type.borrow_mut() = Some(top_type);
93 }
94
95 let (result_stack, _subst, inferred_effects) =
97 self.infer_statements_from(&word.body, &declared_effect.inputs, true)?;
98
99 *self.expected_quotation_type.borrow_mut() = None;
101
102 let line_info = line.map(format_line_prefix).unwrap_or_default();
104 unify_stacks(&declared_effect.outputs, &result_stack).map_err(|e| {
105 format!(
106 "{}Word '{}': declared output stack ({}) doesn't match inferred ({}): {}",
107 line_info, word.name, declared_effect.outputs, result_stack, e
108 )
109 })?;
110
111 for inferred in &inferred_effects {
115 if !self.effect_matches_any(inferred, &declared_effect.effects) {
116 return Err(format!(
117 "{}Word '{}': body produces effect '{}' but no matching effect is declared.\n\
118 Hint: Add '| Yield <type>' to the word's stack effect declaration.",
119 line_info, word.name, inferred
120 ));
121 }
122 }
123
124 for declared in &declared_effect.effects {
127 if !self.effect_matches_any(declared, &inferred_effects) {
128 return Err(format!(
129 "{}Word '{}': declares effect '{}' but body doesn't produce it.\n\
130 Hint: Remove the effect declaration or ensure the body uses yield.",
131 line_info, word.name, declared
132 ));
133 }
134 }
135
136 let aux_stack = self.current_aux_stack.borrow().clone();
138 if aux_stack != StackType::Empty {
139 return Err(format!(
140 "{}Word '{}': aux stack is not empty at word return.\n\
141 Remaining aux stack: {}\n\
142 Every >aux must be matched by a corresponding aux> before the word returns.",
143 line_info, word.name, aux_stack
144 ));
145 }
146
147 *self.current_word.borrow_mut() = None;
149
150 Ok(())
151 }
152
153 pub(super) fn infer_statements_from(
160 &self,
161 statements: &[Statement],
162 start_stack: &StackType,
163 capture_stmt_types: bool,
164 ) -> Result<(StackType, Subst, Vec<SideEffect>), String> {
165 let mut current_stack = start_stack.clone();
166 let mut accumulated_subst = Subst::empty();
167 let mut accumulated_effects: Vec<SideEffect> = Vec::new();
168 let mut skip_next = false;
169
170 for (i, stmt) in statements.iter().enumerate() {
171 if skip_next {
173 skip_next = false;
174 continue;
175 }
176
177 if let Statement::IntLiteral(n) = stmt
180 && let Some(Statement::WordCall {
181 name: next_word, ..
182 }) = statements.get(i + 1)
183 {
184 if next_word == "pick" {
185 let (new_stack, subst) = self.handle_literal_pick(*n, current_stack.clone())?;
186 current_stack = new_stack;
187 accumulated_subst = accumulated_subst.compose(&subst);
188 skip_next = true; continue;
190 } else if next_word == "roll" {
191 let (new_stack, subst) = self.handle_literal_roll(*n, current_stack.clone())?;
192 current_stack = new_stack;
193 accumulated_subst = accumulated_subst.compose(&subst);
194 skip_next = true; continue;
196 }
197 }
198
199 let saved_expected_type = if matches!(stmt, Statement::Quotation { .. }) {
202 let saved = self.expected_quotation_type.borrow().clone();
204
205 if let Some(Statement::WordCall {
207 name: next_word, ..
208 }) = statements.get(i + 1)
209 {
210 if let Some(next_effect) = self.lookup_word_effect(next_word) {
212 if let Some((_rest, quot_type)) = next_effect.inputs.clone().pop()
215 && matches!(quot_type, Type::Quotation(_))
216 {
217 *self.expected_quotation_type.borrow_mut() = Some(quot_type);
218 }
219 }
220 }
221 Some(saved)
222 } else {
223 None
224 };
225
226 if capture_stmt_types && let Some((word_name, _)) = self.current_word.borrow().as_ref()
230 {
231 self.capture_statement_type(word_name, i, ¤t_stack);
232 }
233
234 let (new_stack, subst, effects) = self.infer_statement(stmt, current_stack)?;
235 current_stack = new_stack;
236 accumulated_subst = accumulated_subst.compose(&subst);
237
238 for effect in effects {
240 if !accumulated_effects.contains(&effect) {
241 accumulated_effects.push(effect);
242 }
243 }
244
245 if let Some(saved) = saved_expected_type {
247 *self.expected_quotation_type.borrow_mut() = saved;
248 }
249 }
250
251 Ok((current_stack, accumulated_subst, accumulated_effects))
252 }
253
254 pub(super) fn infer_statements(&self, statements: &[Statement]) -> Result<Effect, String> {
265 let start = StackType::RowVar("input".to_string());
266 let (result, subst, effects) = self.infer_statements_from(statements, &start, false)?;
268
269 let normalized_start = subst.apply_stack(&start);
272 let normalized_result = subst.apply_stack(&result);
273
274 Ok(Effect::with_effects(
275 normalized_start,
276 normalized_result,
277 effects,
278 ))
279 }
280
281 pub(super) fn infer_statement(
283 &self,
284 statement: &Statement,
285 current_stack: StackType,
286 ) -> Result<(StackType, Subst, Vec<SideEffect>), String> {
287 match statement {
288 Statement::IntLiteral(_) => Ok((current_stack.push(Type::Int), Subst::empty(), vec![])),
289 Statement::BoolLiteral(_) => {
290 Ok((current_stack.push(Type::Bool), Subst::empty(), vec![]))
291 }
292 Statement::StringLiteral(_) => {
293 Ok((current_stack.push(Type::String), Subst::empty(), vec![]))
294 }
295 Statement::FloatLiteral(_) => {
296 Ok((current_stack.push(Type::Float), Subst::empty(), vec![]))
297 }
298 Statement::Symbol(_) => Ok((current_stack.push(Type::Symbol), Subst::empty(), vec![])),
299 Statement::Match { arms, span } => self.infer_match(arms, span, current_stack),
300 Statement::WordCall { name, span } => self.infer_word_call(name, span, current_stack),
301 Statement::If {
302 then_branch,
303 else_branch,
304 span,
305 } => self.infer_if(then_branch, else_branch, span, current_stack),
306 Statement::Quotation { id, body, .. } => self.infer_quotation(*id, body, current_stack),
307 }
308 }
309
310 pub(super) fn effect_matches_any(
312 &self,
313 inferred: &SideEffect,
314 declared: &[SideEffect],
315 ) -> bool {
316 declared.iter().any(|decl| match (inferred, decl) {
317 (SideEffect::Yield(_), SideEffect::Yield(_)) => true,
318 })
319 }
320}