spade_typeinference/
expression.rs

1use num::{BigInt, One};
2use spade_common::location_info::{Loc, WithLocation};
3use spade_common::name::Identifier;
4use spade_common::num_ext::InfallibleToBigInt;
5use spade_diagnostics::diagnostic::DiagnosticLevel;
6use spade_diagnostics::{diag_anyhow, Diagnostic};
7use spade_hir::expression::{BinaryOperator, IntLiteralKind, NamedArgument, UnaryOperator};
8use spade_hir::{ExprKind, Expression};
9use spade_macros::trace_typechecker;
10use spade_types::meta_types::MetaType;
11use spade_types::KnownType;
12
13use crate::constraints::{bits_to_store, ce_int, ce_var, ConstraintExpr, ConstraintSource};
14use crate::equation::{TypeVar, TypedExpression};
15use crate::error::{TypeMismatch as Tm, UnificationErrorExt};
16use crate::fixed_types::{t_bit, t_bool};
17use crate::requirements::{ConstantInt, Requirement};
18use crate::{Context, GenericListToken, HasType, Result, TraceStackEntry, TypeState};
19
/// Destructures the `kind` of an expression that the caller already knows the
/// variant of.
///
/// `$expr` must expose `.inner.kind`. When the kind matches `$pattern`, the
/// bindings of the pattern are made available to `$block`, which is then run.
/// Any other kind is treated as a programming error and triggers a panic.
macro_rules! assuming_kind {
    ($pattern:pat = $expr:expr => $block:block) => {
        match &$expr.inner.kind {
            $pattern => $block,
            _ => panic!("Incorrect assumption about expression kind"),
        };
    };
}
29
30impl TypeState {
31    #[trace_typechecker]
32    #[tracing::instrument(level = "trace", skip_all)]
33    pub fn visit_identifier(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
34        assuming_kind!(ExprKind::Identifier(ident) = &expression => {
35            // Add an equation for the anonymous id
36            self.unify_expression_generic_error(
37                expression,
38                &TypedExpression::Name(ident.clone()),
39                ctx
40            )?;
41        });
42        Ok(())
43    }
44
45    #[trace_typechecker]
46    #[tracing::instrument(level = "trace", skip_all)]
47    pub fn visit_type_level_integer(
48        &mut self,
49        expression: &Loc<Expression>,
50        generic_list: &GenericListToken,
51        ctx: &Context,
52    ) -> Result<()> {
53        assuming_kind!(ExprKind::TypeLevelInteger(value) = &expression => {
54            let (t, _size) = self.new_generic_number(expression.loc(), ctx);
55            self.unify(&t, &expression.inner, ctx)
56                .into_diagnostic(expression.loc(), |diag, Tm{e: _, g: _got}| {
57                    diag
58                        .level(DiagnosticLevel::Bug)
59                        .message("Failed to unify integer literal with integer")
60                })?;
61            let generic = self.get_generic_list(generic_list).get(value).ok_or_else(|| {
62                Diagnostic::bug(expression, "Found no generic list when visiting {value}")
63            })?;
64            self.add_requirement(
65                Requirement::FitsIntLiteral {
66                    value: ConstantInt::Generic(generic.clone()),
67                    target_type: t.at_loc(expression)
68                }
69            )
70        });
71        Ok(())
72    }
73
74    #[tracing::instrument(level = "trace", skip_all)]
75    pub fn visit_pipeline_ref(
76        &mut self,
77        expression: &Loc<Expression>,
78        generic_list: &GenericListToken,
79        ctx: &Context,
80    ) -> Result<()> {
81        assuming_kind!(ExprKind::PipelineRef{stage, name, declares_name, depth_typeexpr_id} = &expression => {
82            // If this reference declares the referenced name, add a new equation
83            if *declares_name {
84                let new_var = self.new_generic_type(expression.loc());
85                self.add_equation(TypedExpression::Name(name.clone().inner), new_var)
86            }
87
88            let depth = self.new_generic_tlint(stage.loc());
89            self.add_equation(TypedExpression::Id(*depth_typeexpr_id), depth.clone());
90            let depth = match &stage.inner {
91                spade_hir::expression::PipelineRefKind::Absolute(name) => {
92                    let key = TypedExpression::Name(name.inner.clone());
93                    let var = if !self.equations.contains_key(&key) {
94                        let var = self.new_generic_tlint(stage.loc());
95                        self.add_equation(key.clone(), var.clone());
96                        self.trace_stack.push(TraceStackEntry::PreAddingPipelineLabel(name.inner.clone(), var.clone()));
97                        var
98                    } else {
99                        let var = self.equations.get(&key).unwrap().clone();
100                        self.trace_stack.push(TraceStackEntry::RecoveringPipelineLabel(name.inner.clone(), var.clone()));
101                        var
102                    };
103                    // NOTE: Safe unwrap, depth is fresh
104                    self.unify(&depth, &var, ctx).unwrap()
105                },
106                spade_hir::expression::PipelineRefKind::Relative(expr) => {
107                    let expr_var = self.hir_type_expr_to_var(expr, generic_list)?;
108                    let total_offset = self.new_generic_tlint(stage.loc());
109                    self.add_constraint(
110                        total_offset.clone(),
111                        ConstraintExpr::Sum(
112                            Box::new(ConstraintExpr::Var(expr_var)),
113                            Box::new(ConstraintExpr::Var(self.get_pipeline_state(expression)?
114                                .current_stage_depth.clone()))
115                        ),
116                        stage.loc(),
117                        &total_offset,
118                        ConstraintSource::PipelineRegOffset{reg: expr.loc(), total: self.get_pipeline_state(expr)?.total_depth.loc()}
119                    );
120                    // Safe unwrap, depth is a fresh type var
121                    self.unify(&depth, &total_offset, ctx).unwrap()
122                },
123            };
124
125            let pipeline_state = self.pipeline_state
126                .as_ref()
127                .ok_or_else(|| diag_anyhow!(
128                    expression,
129                    "Expected a pipeline state"
130                ))?;
131            self.add_requirement(Requirement::ValidPipelineOffset {
132                definition_depth: pipeline_state
133                    .total_depth
134                    .clone(),
135                current_stage: pipeline_state.current_stage_depth.clone().nowhere(),
136                reference_offset: depth.at_loc(stage)
137            });
138
139            // Add an equation for the anonymous id
140            self.unify_expression_generic_error(
141                expression,
142                &TypedExpression::Name(name.clone().inner),
143                ctx
144            )?;
145        });
146        Ok(())
147    }
148
149    #[trace_typechecker]
150    #[tracing::instrument(level = "trace", skip_all)]
151    pub fn visit_int_literal(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
152        assuming_kind!(ExprKind::IntLiteral(value, kind) = &expression => {
153            let (t, _size) = match kind {
154                IntLiteralKind::Unsized => self.new_generic_number(expression.loc(), ctx),
155                IntLiteralKind::Signed(size) => {
156                    let (t, size_var) = self.new_split_generic_int(expression.loc(), ctx.symtab);
157                    // NOTE: Safe unwrap, we're unifying a generic int with a size
158                    self.unify(&size_var, &KnownType::Integer(size.to_bigint()).at_loc(expression), ctx).unwrap();
159                    (t, size_var)
160                },
161                IntLiteralKind::Unsigned(size) => {
162                    let (t, size_var) = self.new_split_generic_uint(expression.loc(), ctx.symtab);
163                    // NOTE: Safe unwrap, we're unifying a generic int with a size
164                    self.unify(&size_var, &KnownType::Integer(size.to_bigint()).at_loc(expression), ctx).unwrap();
165                    (t, size_var)
166                }
167            };
168            self.unify(&t, &expression.inner, ctx)
169                .into_diagnostic(expression.loc(), |diag, Tm{e: _, g: _got}| {
170                    diag
171                        .level(DiagnosticLevel::Bug)
172                        .message("Failed to unify integer literal with integer")
173                })?;
174            self.add_requirement(Requirement::FitsIntLiteral {
175                value: ConstantInt::Literal(value.clone()),
176                target_type: t.at_loc(expression)
177            });
178        });
179        Ok(())
180    }
181
182    #[trace_typechecker]
183    #[tracing::instrument(level = "trace", skip_all)]
184    pub fn visit_bool_literal(
185        &mut self,
186        expression: &Loc<Expression>,
187        ctx: &Context,
188    ) -> Result<()> {
189        assuming_kind!(ExprKind::BoolLiteral(_) = &expression => {
190            self.unify_expression_generic_error(expression, &t_bool(ctx.symtab).at_loc(expression), ctx)?;
191        });
192        Ok(())
193    }
194
195    #[trace_typechecker]
196    #[tracing::instrument(level = "trace", skip_all)]
197    pub fn visit_bit_literal(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
198        assuming_kind!(ExprKind::BitLiteral(_) = &expression => {
199            self.unify_expression_generic_error(expression, &t_bit(ctx.symtab).at_loc(expression), ctx)?;
200        });
201        Ok(())
202    }
203
204    #[trace_typechecker]
205    #[tracing::instrument(level = "trace", skip_all)]
206    pub fn visit_tuple_literal(
207        &mut self,
208        expression: &Loc<Expression>,
209        ctx: &Context,
210        generic_list: &GenericListToken,
211    ) -> Result<()> {
212        assuming_kind!(ExprKind::TupleLiteral(inner) = &expression => {
213            for expr in inner {
214                self.visit_expression(expr, ctx, generic_list)?;
215                // NOTE: safe unwrap, we know this expr has a type because we just visited
216            }
217
218            let mut inner_types = vec![];
219            for expr in inner {
220                let t = self.type_of(&TypedExpression::Id(expr.id)).unwrap();
221
222                inner_types.push(t);
223            }
224
225            self.unify_expression_generic_error(
226                expression,
227                &TypeVar::Known(expression.loc(), KnownType::Tuple, inner_types),
228                ctx
229            )?;
230        });
231        Ok(())
232    }
233
234    #[trace_typechecker]
235    #[tracing::instrument(level = "trace", skip_all)]
236    pub fn visit_tuple_index(
237        &mut self,
238        expression: &Loc<Expression>,
239        ctx: &Context,
240        generic_list: &GenericListToken,
241    ) -> Result<()> {
242        assuming_kind!(ExprKind::TupleIndex(tup, index) = &expression => {
243            self.visit_expression(tup, ctx, generic_list)?;
244            let t = self.type_of(&TypedExpression::Id(tup.id))?;
245
246            let inner_types = match t {
247                TypeVar::Known(_, KnownType::Tuple, inner) => inner,
248                ref t @ TypeVar::Known(ref other_source, _, _) => {
249                    return Err(Diagnostic::error(tup.loc(), "Attempt to use tuple indexing on non-tuple")
250                        .primary_label(format!("expected tuple, got {t}"))
251                        .secondary_label(index, "Because this is a tuple index")
252                        .secondary_label(other_source, format!("Type {t} inferred here"))
253                    );
254                }
255                TypeVar::Unknown(_, _, _, MetaType::Type | MetaType::Any) => {
256                    return Err(
257                        Diagnostic::error(tup.as_ref(), "Type of tuple indexee must be known at this point")
258                            .primary_label("The type of this must be known")
259                    )
260                }
261                TypeVar::Unknown(ref other_source, _, _, meta @ (MetaType::Uint | MetaType::Int | MetaType::Number | MetaType::Bool)) => {
262                    return Err(
263                        Diagnostic::error(tup.as_ref(), "Cannot use tuple indexing on a type level number")
264                            .primary_label("Tuple indexing on type level number")
265                        .secondary_label(other_source, format!("Meta-type {meta} inferred here"))
266                    )
267                }
268            };
269
270            if (index.inner as usize) < inner_types.len() {
271                let true_inner_type = inner_types[index.inner as usize].clone();
272                self.unify_expression_generic_error(
273                    expression,
274                    &true_inner_type,
275                    ctx
276                )?
277            } else {
278                return Err(Diagnostic::error(index, "Tuple index out of bounds")
279                    .primary_label(format!("Tuple only has {} elements", inner_types.len()))
280                    .note(format!("     Index: {}", index))
281                    .note(format!("Tuple size: {}", inner_types.len()))
282                );
283            }
284        });
285        Ok(())
286    }
287
288    #[trace_typechecker]
289    #[tracing::instrument(level = "trace", skip_all)]
290    pub fn visit_field_access(
291        &mut self,
292        expression: &Loc<Expression>,
293        ctx: &Context,
294        generic_list: &GenericListToken,
295    ) -> Result<()> {
296        assuming_kind!(ExprKind::FieldAccess(target, field) = &expression => {
297            self.visit_expression(target, ctx, generic_list)?;
298
299            let target_type = self.type_of(&TypedExpression::Id(target.id))?;
300            let self_type = self.type_of(&TypedExpression::Id(expression.id))?;
301
302            let requirement = Requirement::HasField {
303                target_type: target_type.at_loc(target),
304                field: field.clone(),
305                expr: self_type.at_loc(expression)
306            };
307
308            requirement.check_or_add(self, ctx)?;
309        });
310        Ok(())
311    }
312
313    #[trace_typechecker]
314    #[tracing::instrument(level = "trace", skip_all)]
315    pub fn visit_method_call(
316        &mut self,
317        expression: &Loc<Expression>,
318        ctx: &Context,
319        generic_list: &GenericListToken,
320    ) -> Result<()> {
321        assuming_kind!(ExprKind::MethodCall{call_kind, target, name, args, turbofish} = &expression => {
322            // NOTE: We don't visit_expression here as it is being added to the argument_list
323            // which we *do* visit
324            // self.visit_expression(target, ctx, generic_list)?;
325
326            let args_with_self = args.clone().map(|mut args| {
327                match &mut args {
328                    spade_hir::ArgumentList::Named(inner) => {
329                        inner.push(NamedArgument::Full(
330                            Identifier("self".to_string()).at_loc(target),
331                            target.as_ref().clone()
332                        ))
333                    },
334                    spade_hir::ArgumentList::Positional(list) => list.insert(0, target.as_ref().clone()),
335                };
336                args
337            });
338
339            self.visit_argument_list(&args_with_self, ctx, generic_list)?;
340
341            let target_type = self.type_of(&TypedExpression::Id(target.id))?;
342            let self_type = self.type_of(&TypedExpression::Id(expression.id))?;
343
344            let trait_list = if let TypeVar::Unknown(_, _, trait_list, MetaType::Type) = &target_type {
345                if !trait_list.inner.is_empty() {
346                    Some(trait_list.clone())
347                } else {
348                    None
349                }
350            } else {
351                None
352            };
353
354            let requirement = Requirement::HasMethod {
355                expr_id: expression.map_ref(|e| e.id),
356                target_type: target_type.at_loc(target),
357                trait_list,
358                method: name.clone(),
359                expr: self_type.at_loc(expression),
360                args: args_with_self,
361                turbofish: turbofish.clone(),
362                prev_generic_list: generic_list.clone(),
363                call_kind: call_kind.clone()
364            };
365
366            requirement.check_or_add(self, ctx)?
367        });
368        Ok(())
369    }
370
371    #[trace_typechecker]
372    #[tracing::instrument(level = "trace", skip_all)]
373    pub fn visit_array_literal(
374        &mut self,
375        expression: &Loc<Expression>,
376        ctx: &Context,
377        generic_list: &GenericListToken,
378    ) -> Result<()> {
379        assuming_kind!(ExprKind::ArrayLiteral(members) = &expression => {
380            for expr in members {
381                self.visit_expression(expr, ctx, generic_list)?;
382            }
383
384            // unify all elements in array pairwise, e.g. unify(0, 1), unify(1, 2), ...
385            for (l, r) in members.iter().zip(members.iter().skip(1)) {
386                self.unify(r, l, ctx)
387                    .into_diagnostic(r, |diag, Tm{e: expected, g: _got}| {
388                        diag.message(format!(
389                            "Array element type mismatch. Expected {}",
390                            expected
391                        ))
392                        .primary_label(format!("Expected {}", expected))
393                        .secondary_label(members.first().unwrap().loc(), "To match this".to_string())
394                    })?;
395            }
396
397            let inner_type = if members.is_empty() {
398                self.new_generic_type(expression.loc())
399            }
400            else {
401                members[0].get_type(self)?
402            };
403
404            let size_type = TypeVar::Known(expression.loc(), KnownType::Integer(members.len().to_bigint()), vec![]);
405            let result_type = TypeVar::array(
406                expression.loc(),
407                inner_type,
408                size_type,
409            );
410
411            self.unify_expression_generic_error(expression, &result_type, ctx)?;
412        });
413        Ok(())
414    }
415
416    pub fn visit_array_shorthand_literal(
417        &mut self,
418        expression: &Loc<Expression>,
419        ctx: &Context,
420        generic_list: &GenericListToken,
421    ) -> Result<()> {
422        assuming_kind!(ExprKind::ArrayShorthandLiteral(expr, amount) = &expression => {
423            self.visit_expression(expr, ctx, generic_list)?;
424
425
426            let inner_type = expr.get_type(self)?;
427            let size_type = self.visit_const_generic_with_id(amount, generic_list, ConstraintSource::ArraySize, ctx)?;
428            // Force the type to be a uint
429            let uint_type = self.new_generic_tluint(expression.loc());
430            self.unify(&size_type, &uint_type, ctx).into_default_diagnostic(expression.loc())?;
431
432            let result_type = TypeVar::array(expression.loc(), inner_type, size_type);
433
434            self.unify_expression_generic_error(expression, &result_type, ctx)?;
435        });
436        Ok(())
437    }
438
439    #[trace_typechecker]
440    #[tracing::instrument(level = "trace", skip_all)]
441    pub fn visit_create_ports(
442        &mut self,
443        expression: &Loc<Expression>,
444        ctx: &Context,
445        _generic_list: &GenericListToken,
446    ) -> Result<()> {
447        assuming_kind!(ExprKind::CreatePorts = &expression => {
448            let inner_type = self.new_generic_type(expression.loc());
449            let inverted = TypeVar::Known(expression.loc(), KnownType::Inverted, vec![inner_type.clone()]);
450            let compound = TypeVar::tuple(expression.loc(), vec![inner_type, inverted]);
451            self.unify_expression_generic_error(expression, &compound, ctx)?;
452        });
453        Ok(())
454    }
455
456    #[trace_typechecker]
457    #[tracing::instrument(level = "trace", skip_all)]
458    pub fn visit_index(
459        &mut self,
460        expression: &Loc<Expression>,
461        ctx: &Context,
462        generic_list: &GenericListToken,
463    ) -> Result<()> {
464        assuming_kind!(ExprKind::Index(target, index) = &expression => {
465            // Visit child nodes
466            self.visit_expression(target, ctx, generic_list)?;
467            self.visit_expression(index, ctx, generic_list)?;
468
469            // Add constraints
470            let inner_type = self.new_generic_type(expression.loc());
471
472            // Unify inner type with this expression
473            self.unify_expression_generic_error(
474                expression,
475                &inner_type,
476                ctx
477            )?;
478
479            let array_size = self.new_generic_tluint(expression.loc());
480            let (int_type, int_size) = self.new_split_generic_uint(index.loc(), ctx.symtab);
481
482            // NOTE[et]: Only used for size constraints of this exact type - this can be a
483            // requirement instead, that way we remove a lot of complexity! :D
484            self.add_constraint(
485                int_size,
486                bits_to_store(ce_var(&array_size) - ce_int(BigInt::one())),
487                index.loc(),
488                &int_type,
489                ConstraintSource::ArrayIndexing
490            );
491
492            self.unify(&index.inner, &int_type, ctx)
493                .into_diagnostic(index.as_ref(), |diag, Tm{e: _expected, g: got}| {
494
495                    diag.message(format!("Index must be an integer, got {}", got))
496                        .primary_label("Expected integer".to_string())
497                })?;
498
499            let array_type = TypeVar::array(
500                expression.loc(),
501                expression.get_type(self)?,
502                array_size.clone()
503            );
504            self.add_requirement(Requirement::ArrayIndexeeIsNonZero {
505                index: index.loc(),
506                array: array_type.clone().at_loc(target),
507                array_size: array_size.clone().at_loc(index)
508            });
509            self.unify(&target.inner, &array_type, ctx)
510                .into_diagnostic(target.as_ref(), |diag, Tm{e: _expected, g: got}| {
511                    diag
512                        .message(format!("Index target must be an array, got {}", got))
513                        .primary_label("Expected array".to_string())
514                })?;
515        });
516        Ok(())
517    }
518
519    #[trace_typechecker]
520    #[tracing::instrument(level = "trace", skip_all)]
521    pub fn visit_range_index(
522        &mut self,
523        expression: &Loc<Expression>,
524        ctx: &Context,
525        generic_list: &GenericListToken,
526    ) -> Result<()> {
527        assuming_kind!(ExprKind::RangeIndex{
528            target,
529            ref start,
530            ref end,
531        } = &expression => {
532            self.visit_expression(target, ctx, generic_list)?;
533            // Add constraints
534            let inner_type = self.new_generic_type(target.loc());
535
536            let start_var = self.visit_const_generic_with_id(start, generic_list, ConstraintSource::RangeIndex, ctx)?;
537            let end_var = self.visit_const_generic_with_id(end, generic_list, ConstraintSource::RangeIndex, ctx)?;
538
539            let in_array_size = self.new_generic_tluint(target.loc());
540            let in_array_type = TypeVar::array(expression.loc(), inner_type.clone(), in_array_size.clone());
541            let out_array_size = self.new_generic_tluint(target.loc());
542            let out_array_type = TypeVar::array(expression.loc(), inner_type.clone(), out_array_size.clone());
543
544            let out_size_constraint = ConstraintExpr::Var(end_var.clone()) - ConstraintExpr::Var(start_var.clone());
545            self.add_constraint(out_array_size, out_size_constraint, expression.loc(), &out_array_type, ConstraintSource::RangeIndex);
546
547            self.add_requirement(Requirement::RangeIndexEndAfterStart { expr: expression.loc(), start: start_var.clone().at_loc(&start), end: end_var.clone().at_loc(end) });
548            self.add_requirement(Requirement::RangeIndexInArray { index: end_var.at_loc(end), size: in_array_size.at_loc(&target.loc()) });
549
550            self.unify(&expression.inner, &out_array_type, ctx)
551                .into_default_diagnostic(expression)?;
552
553
554            self.unify(&target.inner, &in_array_type, ctx)
555                .into_diagnostic(target.as_ref(), |diag, Tm{e: _expected, g: got}| {
556                    diag
557                        .message(format!("Index target must be an array, got {}", got))
558                        .primary_label("Expected array".to_string())
559                })?;
560        });
561        Ok(())
562    }
563
564    #[trace_typechecker]
565    #[tracing::instrument(level = "trace", skip_all)]
566    pub fn visit_block_expr(
567        &mut self,
568        expression: &Loc<Expression>,
569        ctx: &Context,
570        generic_list: &GenericListToken,
571    ) -> Result<()> {
572        assuming_kind!(ExprKind::Block(block) = expression => {
573            self.visit_block(block, ctx, generic_list)?;
574
575            if let Some(result) = &block.result {
576                // Unify the return type of the block with the type of this expression
577                self.unify(&expression.inner, &result.inner, ctx)
578                    // NOTE: We could be more specific about this error specifying
579                    // that the type of the block must match the return type, though
580                    // that might just be spammy.
581                    .into_default_diagnostic(result)?;
582            } else {
583                // Block without return value. Unify with unit type.
584                self.unify(&expression.inner, &TypeVar::unit(expression.loc()), ctx)
585                    .into_diagnostic(Loc::nowhere(()), |err, Tm{g: _, e: _}| {
586                        diag_anyhow!(
587                            Loc::nowhere(()),
588                            "This error shouldn't be possible: {err:?}"
589                        )})?;
590            }
591        });
592        Ok(())
593    }
594
595    #[trace_typechecker]
596    #[tracing::instrument(level = "trace", skip_all)]
597    pub fn visit_if(
598        &mut self,
599        expression: &Loc<Expression>,
600        ctx: &Context,
601        generic_list: &GenericListToken,
602    ) -> Result<()> {
603        assuming_kind!(ExprKind::If(cond, on_true, on_false) = &expression => {
604            self.visit_expression(cond, ctx, generic_list)?;
605            self.visit_expression(on_true, ctx, generic_list)?;
606            self.visit_expression(on_false, ctx, generic_list)?;
607
608            self.unify(&cond.inner, &t_bool(ctx.symtab).at_loc(cond), ctx)
609                .into_diagnostic(cond.as_ref(), |diag, Tm{e: _expected, g: got}| {
610                    diag.
611                        message(format!("If condition must be a bool, got {}", got))
612                        .primary_label("Expected boolean")
613                })?;
614            self.unify(&on_false.inner, &on_true.inner, ctx)
615                .into_diagnostic(on_false.as_ref(), |diag, Tm{e: expected, g: got}| {
616                    diag.message("If branches have incompatible type")
617                        .primary_label(format!("But this has type {got}"))
618                        .secondary_label(on_true.as_ref(), format!("This branch has type {expected}"))
619                })?;
620            self.unify(expression, &on_false.inner, ctx)
621                .into_default_diagnostic(expression)?;
622        });
623        Ok(())
624    }
625
626    #[trace_typechecker]
627    #[tracing::instrument(level = "trace", skip_all)]
628    pub fn visit_match(
629        &mut self,
630        expression: &Loc<Expression>,
631        ctx: &Context,
632        generic_list: &GenericListToken,
633    ) -> Result<()> {
634        assuming_kind!(ExprKind::Match(cond, branches) = &expression => {
635            self.visit_expression(cond, ctx, generic_list)?;
636
637            for (i, (pattern, result)) in branches.iter().enumerate() {
638                self.visit_pattern(pattern, ctx, generic_list)?;
639
640                self.unify(pattern, &cond.inner, ctx)
641                    .into_default_diagnostic(pattern)?;
642
643                self.visit_expression(result, ctx, generic_list)?;
644
645                if i != 0 {
646                    self.unify(&branches[0].1, result, ctx).into_diagnostic(
647                        result,
648                        |diag, Tm{e: expected, g: got}| {
649                            diag.message("Match branches have incompatible type")
650                                .primary_label(format!("This branch has type {got}"))
651                                .secondary_label(&branches[0].1, format!("But this one has type {expected}"))
652                        }
653                    )?;
654                }
655            }
656
657            assert!(
658                !branches.is_empty(),
659                "Empty match statements should be checked by ast lowering"
660            );
661
662            self.unify_expression_generic_error(&branches[0].1, expression, ctx)?;
663        });
664        Ok(())
665    }
666
667    #[trace_typechecker]
668    #[tracing::instrument(level = "trace", skip_all)]
669    pub fn visit_binary_operator(
670        &mut self,
671        expression: &Loc<Expression>,
672        ctx: &Context,
673        generic_list: &GenericListToken,
674    ) -> Result<()> {
675        assuming_kind!(ExprKind::BinaryOperator(lhs, op, rhs) = &expression => {
676            self.visit_expression(lhs, ctx, generic_list)?;
677            self.visit_expression(rhs, ctx, generic_list)?;
678            match op.inner {
679                BinaryOperator::Add
680                | BinaryOperator::Sub | BinaryOperator::Mul if self.use_wordlenght_inference => {
681                    let lhs_t = self.new_generic_int(expression.loc(), ctx.symtab);
682                    self.unify_expression_generic_error(lhs, &lhs_t, ctx)?;
683                    let rhs_t = self.new_generic_int(expression.loc(), ctx.symtab);
684                    self.unify_expression_generic_error(rhs, &rhs_t, ctx)?;
685                    let result_t = self.new_generic_int(expression.loc(), ctx.symtab);
686                    self.unify_expression_generic_error(expression, &result_t, ctx)?;
687                }
688                BinaryOperator::Add
689                | BinaryOperator::Sub => {
690                    let (in_t, lhs_size) = self.new_generic_number(expression.loc(), ctx);
691                    let (result_t, result_size) = self.new_generic_number(expression.loc(), ctx);
692
693                    self.add_constraint(
694                        result_size.clone(),
695                        ce_var(&lhs_size) + ce_int(BigInt::one()),
696                        expression.loc(),
697                        &result_t,
698                        ConstraintSource::AdditionOutput
699                    );
700                    self.add_constraint(
701                        lhs_size.clone(),
702                        ce_var(&result_size) + -ce_int(BigInt::one()),
703                        lhs.loc(),
704                        &in_t,
705                        ConstraintSource::AdditionOutput
706                    );
707
708                    self.unify_expression_generic_error(lhs, &in_t, ctx)?;
709                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
710                    self.unify_expression_generic_error(expression, &result_t, ctx)?;
711
712                    self.add_requirement(Requirement::SharedBase(vec![
713                        in_t.at_loc(lhs),
714                        result_t.at_loc(expression)
715                    ]));
716
717                }
718                BinaryOperator::Mul => {
719                    let (lhs_t, lhs_size) = self.new_generic_number(expression.loc(), ctx);
720                    let (rhs_t, rhs_size) = self.new_generic_number(expression.loc(), ctx);
721                    let (result_t, result_size) = self.new_generic_number(expression.loc(), ctx);
722
723                    // Result size is sum of input sizes
724                    self.add_constraint(
725                        result_size.clone(),
726                        ce_var(&lhs_size) + ce_var(&rhs_size),
727                        expression.loc(),
728                        &result_t,
729                        ConstraintSource::MultOutput
730                    );
731                    self.add_constraint(
732                        lhs_size.clone(),
733                        ce_var(&result_size) + -ce_var(&rhs_size),
734                        lhs.loc(),
735                        &lhs_t,
736                        ConstraintSource::MultOutput
737                    );
738                    self.add_constraint(rhs_size.clone(),
739                        ce_var(&result_size) + -ce_var(&lhs_size),
740                        rhs.loc(),
741                        &rhs_t
742                        , ConstraintSource::MultOutput
743                    );
744
745                    self.unify_expression_generic_error(lhs, &lhs_t, ctx)?;
746                    self.unify_expression_generic_error(rhs, &rhs_t, ctx)?;
747                    self.unify_expression_generic_error(expression, &result_t, ctx)?;
748
749                    self.add_requirement(Requirement::SharedBase(vec![
750                        lhs_t.at_loc(lhs),
751                        rhs_t.at_loc(rhs),
752                        result_t.at_loc(expression)
753                    ]));
754                }
755                // Division, being integer division has the same width out as in
756                BinaryOperator::Div | BinaryOperator::Mod => {
757                    let (int_type, _size) = self.new_generic_number(expression.loc(), ctx);
758
759                    self.unify_expression_generic_error(lhs, &int_type, ctx)?;
760                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
761                    self.unify_expression_generic_error(expression, &rhs.inner, ctx)?;
762                },
763                // Shift operators have the same width in as they do out
764                BinaryOperator::LeftShift
765                | BinaryOperator::BitwiseAnd
766                | BinaryOperator::BitwiseXor
767                | BinaryOperator::BitwiseOr
768                | BinaryOperator::ArithmeticRightShift
769                | BinaryOperator::RightShift => {
770                    let (int_type, _size) = self.new_generic_number(expression.loc(), ctx);
771
772                    // FIXME: Make generic over types that can be bitmanipulated
773                    self.unify_expression_generic_error(lhs, &int_type, ctx)?;
774                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
775                    self.unify_expression_generic_error(expression, &rhs.inner, ctx)?;
776                }
777                BinaryOperator::Eq
778                | BinaryOperator::NotEq
779                | BinaryOperator::Gt
780                | BinaryOperator::Lt
781                | BinaryOperator::Ge
782                | BinaryOperator::Le => {
783                    let (base, _size) = self.new_generic_number(expression.loc(), ctx);
784                    // FIXME: Make generic over types that can be compared
785                    self.unify_expression_generic_error(lhs, &base, ctx)?;
786                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
787                    self.unify_expression_generic_error(expression, &t_bool(ctx.symtab).at_loc(expression), ctx)?;
788                }
789                BinaryOperator::LogicalAnd
790                | BinaryOperator::LogicalOr
791                | BinaryOperator::LogicalXor => {
792                    self.unify_expression_generic_error(lhs, &t_bool(ctx.symtab).at_loc(expression), ctx)?;
793                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
794
795                    self.unify_expression_generic_error(expression, &t_bool(ctx.symtab).at_loc(expression), ctx)?;
796                }
797            }
798        });
799        Ok(())
800    }
801
802    #[trace_typechecker]
803    #[tracing::instrument(level = "trace", skip_all)]
804    pub fn visit_unary_operator(
805        &mut self,
806        expression: &Loc<Expression>,
807        ctx: &Context,
808        generic_list: &GenericListToken,
809    ) -> Result<()> {
810        assuming_kind!(ExprKind::UnaryOperator(op, operand) = &expression => {
811            self.visit_expression(operand, ctx, generic_list)?;
812            match &op.inner {
813                UnaryOperator::Sub => {
814                    let int_type = self.new_generic_int(expression.loc(), ctx.symtab);
815                    self.unify_expression_generic_error(operand, &int_type, ctx)?;
816                    self.unify_expression_generic_error(expression, &int_type, ctx)?
817                }
818                UnaryOperator::BitwiseNot => {
819                    let (number_type, _) = self.new_generic_number(expression.loc(), ctx);
820                    self.unify_expression_generic_error(operand, &number_type, ctx)?;
821                    self.unify_expression_generic_error(expression, &number_type, ctx)?
822                }
823                UnaryOperator::Not => {
824                    self.unify_expression_generic_error(operand, &t_bool(ctx.symtab).at_loc(expression), ctx)?;
825                    self.unify_expression_generic_error(expression, &t_bool(ctx.symtab).at_loc(expression), ctx)?
826                }
827                UnaryOperator::Dereference => {
828                    let result_type = self.new_generic_type(expression.loc());
829                    let reference_type = TypeVar::wire(expression.loc(), result_type.clone());
830                    self.unify_expression_generic_error(operand, &reference_type, ctx)?;
831                    self.unify_expression_generic_error(expression, &result_type, ctx)?
832                }
833                UnaryOperator::Reference => {
834                    let result_type = self.new_generic_type(expression.loc());
835                    let reference_type = TypeVar::wire(expression.loc(), result_type.clone());
836                    self.unify_expression_generic_error(operand, &result_type, ctx)?;
837                    self.unify_expression_generic_error(expression, &reference_type, ctx)?
838                }
839            }
840        });
841        Ok(())
842    }
843}