spade_typeinference/
expression.rs

1use num::{BigInt, One};
2use spade_common::location_info::{Loc, WithLocation};
3use spade_common::name::Identifier;
4use spade_common::num_ext::InfallibleToBigInt;
5use spade_diagnostics::diagnostic::DiagnosticLevel;
6use spade_diagnostics::{diag_anyhow, Diagnostic};
7use spade_hir::expression::{BinaryOperator, IntLiteralKind, NamedArgument, UnaryOperator};
8use spade_hir::{ExprKind, Expression};
9use spade_macros::trace_typechecker;
10use spade_types::meta_types::MetaType;
11use spade_types::KnownType;
12
13use crate::constraints::{bits_to_store, ce_int, ce_var, ConstraintExpr, ConstraintSource};
14use crate::equation::{TypeVar, TypedExpression};
15use crate::error::{TypeMismatch as Tm, UnificationErrorExt};
16use crate::requirements::{ConstantInt, Requirement};
17use crate::{Context, GenericListToken, HasType, Result, TraceStackEntry, TypeState};
18
/// Destructures `$expr.inner.kind` (by reference) with `$pattern` and runs `$block` with the
/// pattern's bindings in scope.
///
/// The `visit_*` methods use this to unpack the single `ExprKind` they handle. If the kind
/// does not match, the macro panics; presumably this only happens on a dispatch bug in the
/// caller, not on user input.
///
/// Usage: `assuming_kind!(ExprKind::Identifier(ident) = &expression => { ... })`
macro_rules! assuming_kind {
    ($pattern:pat = $expr:expr => $block:block) => {
        match &$expr.inner.kind {
            $pattern => $block,
            _ => panic!("Incorrect assumption about expression kind"),
        };
    };
}
28
29impl TypeState {
30    #[trace_typechecker]
31    #[tracing::instrument(level = "trace", skip_all)]
32    pub fn visit_identifier(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
33        assuming_kind!(ExprKind::Identifier(ident) = &expression => {
34            // Add an equation for the anonymous id
35            self.unify_expression_generic_error(
36                expression,
37                &TypedExpression::Name(ident.clone()),
38                ctx
39            )?;
40        });
41        Ok(())
42    }
43
44    #[trace_typechecker]
45    #[tracing::instrument(level = "trace", skip_all)]
46    pub fn visit_type_level_integer(
47        &mut self,
48        expression: &Loc<Expression>,
49        generic_list: &GenericListToken,
50        ctx: &Context,
51    ) -> Result<()> {
52        assuming_kind!(ExprKind::TypeLevelInteger(value) = &expression => {
53            let (t, _size) = self.new_generic_number(expression.loc(), ctx);
54            self.unify(&t, &expression.inner, ctx)
55                .into_diagnostic(expression.loc(), |diag, _tm| {
56                    diag
57                        .level(DiagnosticLevel::Bug)
58                        .message("Failed to unify integer literal with integer")
59                }, self)?;
60            let generic_list = self
61                .get_generic_list(generic_list)
62                .ok_or_else(|| {
63                    diag_anyhow!(expression, "Found no generic list here")
64                })?;
65            let generic = generic_list
66                .get(value)
67                .ok_or_else(|| {
68                    diag_anyhow!(expression, "Found no entry for {value:?} in generic list. It has {generic_list:?}")
69                })?;
70            self.add_requirement(
71                Requirement::FitsIntLiteral {
72                    value: ConstantInt::Generic(generic.clone()),
73                    target_type: t.at_loc(expression)
74                }
75            )
76        });
77        Ok(())
78    }
79
80    #[tracing::instrument(level = "trace", skip_all)]
81    pub fn visit_pipeline_ref(
82        &mut self,
83        expression: &Loc<Expression>,
84        generic_list: &GenericListToken,
85        ctx: &Context,
86    ) -> Result<()> {
87        assuming_kind!(ExprKind::PipelineRef{stage, name, declares_name, depth_typeexpr_id} = &expression => {
88            // If this reference declares the referenced name, add a new equation
89            if *declares_name {
90                let new_var = self.new_generic_type(expression.loc());
91                self.add_equation(TypedExpression::Name(name.clone().inner), new_var)
92            }
93
94            let depth = self.new_generic_tlint(stage.loc());
95            self.add_equation(TypedExpression::Id(*depth_typeexpr_id), depth.clone());
96            let depth = match &stage.inner {
97                spade_hir::expression::PipelineRefKind::Absolute(name) => {
98                    let key = TypedExpression::Name(name.inner.clone());
99                    let var = if !self.equations.contains_key(&key) {
100                        let var = self.new_generic_tlint(stage.loc());
101                        self.add_equation(key.clone(), var.clone());
102                        self.trace_stack.push(TraceStackEntry::PreAddingPipelineLabel(name.inner.clone(), var.debug_resolve(self)));
103                        var
104                    } else {
105                        let var = self.equations.get(&key).unwrap().clone();
106                        self.trace_stack.push(TraceStackEntry::RecoveringPipelineLabel(name.inner.clone(), var.debug_resolve(self)));
107                        var
108                    };
109                    // NOTE: Safe unwrap, depth is fresh
110                    self.unify(&depth, &var, ctx).unwrap()
111                },
112                spade_hir::expression::PipelineRefKind::Relative(expr) => {
113                    let expr_var = self.hir_type_expr_to_var(expr, generic_list)?;
114                    let total_offset = self.new_generic_tlint(stage.loc());
115                    self.add_constraint(
116                        total_offset.clone(),
117                        ConstraintExpr::Sum(
118                            Box::new(ConstraintExpr::Var(expr_var)),
119                            Box::new(ConstraintExpr::Var(self.get_pipeline_state(expression)?
120                                .current_stage_depth.clone()))
121                        ),
122                        stage.loc(),
123                        &total_offset,
124                        ConstraintSource::PipelineRegOffset{reg: expr.loc(), total: self.get_pipeline_state(expr)?.total_depth.loc()}
125                    );
126                    // Safe unwrap, depth is a fresh type var
127                    self.unify(&depth, &total_offset, ctx).unwrap()
128                },
129            };
130
131            let pipeline_state = self.pipeline_state
132                .as_ref()
133                .ok_or_else(|| diag_anyhow!(
134                    expression,
135                    "Expected a pipeline state"
136                ))?;
137            self.add_requirement(Requirement::ValidPipelineOffset {
138                definition_depth: pipeline_state
139                    .total_depth
140                    .clone(),
141                current_stage: pipeline_state.current_stage_depth.clone().nowhere(),
142                reference_offset: depth.at_loc(stage)
143            });
144
145            // Add an equation for the anonymous id
146            self.unify_expression_generic_error(
147                expression,
148                &TypedExpression::Name(name.clone().inner),
149                ctx
150            )?;
151        });
152        Ok(())
153    }
154
155    #[trace_typechecker]
156    #[tracing::instrument(level = "trace", skip_all)]
157    pub fn visit_int_literal(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
158        assuming_kind!(ExprKind::IntLiteral(value, kind) = &expression => {
159            let (t, _size) = match kind {
160                IntLiteralKind::Unsized => self.new_generic_number(expression.loc(), ctx),
161                IntLiteralKind::Signed(size) => {
162                    let (t, size_var) = self.new_split_generic_int(expression.loc(), ctx.symtab);
163                    // NOTE: Safe unwrap, we're unifying a generic int with a size
164                    size_var
165                        .unify_with(&TypeVar::Known(
166                            expression.loc(),
167                            KnownType::Integer(size.to_bigint()),
168                            vec![]).insert(self),
169                            self
170                        )
171                        .commit(self, ctx)
172                        .unwrap();
173                    (t, size_var)
174                },
175                IntLiteralKind::Unsigned(size) => {
176                    let (t, size_var) = self.new_split_generic_uint(expression.loc(), ctx.symtab);
177                    // NOTE: Safe unwrap, we're unifying a generic int with a size
178                    size_var
179                        .unify_with(&self.new_concrete_int(size.clone(), expression.loc()), self)
180                        .commit(self, ctx)
181                        .unwrap();
182                    (t, size_var)
183                }
184            };
185            self.unify(&t, &expression.inner, ctx)
186                .into_diagnostic(expression.loc(), |diag, Tm{e: _, g: _got}| {
187                    diag
188                        .level(DiagnosticLevel::Bug)
189                        .message("Failed to unify integer literal with integer")
190                }, self)?;
191            self.add_requirement(Requirement::FitsIntLiteral {
192                value: ConstantInt::Literal(value.clone()),
193                target_type: t.at_loc(expression)
194            });
195        });
196        Ok(())
197    }
198
199    #[trace_typechecker]
200    #[tracing::instrument(level = "trace", skip_all)]
201    pub fn visit_bool_literal(
202        &mut self,
203        expression: &Loc<Expression>,
204        ctx: &Context,
205    ) -> Result<()> {
206        assuming_kind!(ExprKind::BoolLiteral(_) = &expression => {
207            expression
208                .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
209                .commit(self, ctx)
210                .into_default_diagnostic(expression, self)?;
211        });
212        Ok(())
213    }
214
215    #[trace_typechecker]
216    #[tracing::instrument(level = "trace", skip_all)]
217    pub fn visit_tri_literal(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
218        assuming_kind!(ExprKind::TriLiteral(_) = &expression => {
219            expression
220                .unify_with(&self.t_tri(expression.loc(), ctx.symtab), self)
221                .commit(self, ctx)
222                .into_default_diagnostic(expression, self)?
223        });
224        Ok(())
225    }
226
227    #[trace_typechecker]
228    #[tracing::instrument(level = "trace", skip_all)]
229    pub fn visit_tuple_literal(
230        &mut self,
231        expression: &Loc<Expression>,
232        ctx: &Context,
233        generic_list: &GenericListToken,
234    ) -> Result<()> {
235        assuming_kind!(ExprKind::TupleLiteral(inner) = &expression => {
236            for expr in inner {
237                self.visit_expression(expr, ctx, generic_list);
238                // NOTE: safe unwrap, we know this expr has a type because we just visited
239            }
240
241            let mut inner_types = vec![];
242            for expr in inner {
243                let t = self.type_of(&TypedExpression::Id(expr.id));
244
245                inner_types.push(t);
246            }
247
248            expression
249                .unify_with(
250                    &TypeVar::Known(expression.loc(), KnownType::Tuple, inner_types).insert(self),
251                    self
252                )
253                .commit(self, ctx)
254                .into_default_diagnostic(expression, self)?
255        });
256        Ok(())
257    }
258
259    #[trace_typechecker]
260    #[tracing::instrument(level = "trace", skip_all)]
261    pub fn visit_tuple_index(
262        &mut self,
263        expression: &Loc<Expression>,
264        ctx: &Context,
265        generic_list: &GenericListToken,
266    ) -> Result<()> {
267        assuming_kind!(ExprKind::TupleIndex(tup, index) = &expression => {
268            self.visit_expression(tup, ctx, generic_list);
269            let t_id = self.type_of(&TypedExpression::Id(tup.id));
270
271            let inner_types = match t_id.resolve(self) {
272                TypeVar::Known(_, KnownType::Tuple, inner) => inner,
273                t @ TypeVar::Known(ref other_source, _, _) => {
274                    return Err(Diagnostic::error(tup.loc(), "Attempt to use tuple indexing on non-tuple")
275                        .primary_label(format!("expected tuple, got {t}", t = t.display(self)))
276                        .secondary_label(index, "Because this is a tuple index")
277                        .secondary_label(other_source, format!("Type {t} inferred here", t = t.display(self)))
278                    );
279                }
280                TypeVar::Unknown(_, _, _, MetaType::Type | MetaType::Any) => {
281                    return Err(
282                        Diagnostic::error(tup.as_ref(), "Type of tuple indexee must be known at this point")
283                            .primary_label("The type of this must be known")
284                    )
285                }
286                TypeVar::Unknown(ref other_source, _, _, meta @ (MetaType::Uint | MetaType::Int | MetaType::Number | MetaType::Bool | MetaType::Str)) => {
287                    return Err(
288                        Diagnostic::error(tup.as_ref(), "Cannot use tuple indexing on a type level number")
289                            .primary_label("Tuple indexing on type level number")
290                        .secondary_label(other_source, format!("Meta-type {meta} inferred here"))
291                    )
292                }
293            };
294
295            if (index.inner as usize) < inner_types.len() {
296                let true_inner_type = inner_types[index.inner as usize].clone();
297                self.unify_expression_generic_error(
298                    expression,
299                    &true_inner_type,
300                    ctx
301                )?
302            } else {
303                return Err(Diagnostic::error(index, "Tuple index out of bounds")
304                    .primary_label(format!("Tuple only has {} elements", inner_types.len()))
305                    .note(format!("     Index: {}", index))
306                    .note(format!("Tuple size: {}", inner_types.len()))
307                );
308            }
309        });
310        Ok(())
311    }
312
313    #[trace_typechecker]
314    #[tracing::instrument(level = "trace", skip_all)]
315    pub fn visit_field_access(
316        &mut self,
317        expression: &Loc<Expression>,
318        ctx: &Context,
319        generic_list: &GenericListToken,
320    ) -> Result<()> {
321        assuming_kind!(ExprKind::FieldAccess(target, field) = &expression => {
322            self.visit_expression(target, ctx, generic_list);
323
324            let target_type = self.type_of(&TypedExpression::Id(target.id));
325            let self_type = self.type_of(&TypedExpression::Id(expression.id));
326
327            let requirement = Requirement::HasField {
328                target_type: target_type.at_loc(target),
329                field: field.clone(),
330                expr: self_type.at_loc(expression)
331            };
332
333            requirement.check_or_add(self, ctx)?;
334        });
335        Ok(())
336    }
337
338    #[trace_typechecker]
339    #[tracing::instrument(level = "trace", skip_all)]
340    pub fn visit_method_call(
341        &mut self,
342        expression: &Loc<Expression>,
343        ctx: &Context,
344        generic_list: &GenericListToken,
345    ) -> Result<()> {
346        assuming_kind!(ExprKind::MethodCall{call_kind, target, name, args, turbofish, safety: _} = &expression => {
347            // NOTE: We don't visit_expression here as it is being added to the argument_list
348            // which we *do* visit
349            // self.visit_expression(target, ctx, generic_list)?;
350
351            let args_with_self = args.clone().map(|mut args| {
352                match &mut args {
353                    spade_hir::ArgumentList::Named(inner) => {
354                        inner.push(NamedArgument::Full(
355                            Identifier("self".to_string()).at_loc(target),
356                            target.as_ref().clone()
357                        ))
358                    },
359                    spade_hir::ArgumentList::Positional(list) => list.insert(0, target.as_ref().clone()),
360                };
361                args
362            });
363
364            self.visit_argument_list(&args_with_self, ctx, generic_list)?;
365
366            let target_type = self.type_of(&TypedExpression::Id(target.id));
367            let self_type = self.type_of(&TypedExpression::Id(expression.id));
368
369            let trait_list = if let TypeVar::Unknown(_, _, trait_list, MetaType::Type) = &target_type.resolve(self) {
370                if !trait_list.inner.is_empty() {
371                    Some(trait_list.clone())
372                } else {
373                    None
374                }
375            } else {
376                None
377            };
378
379            let requirement = Requirement::HasMethod {
380                expr_id: expression.map_ref(|e| e.id),
381                target_type: target_type.at_loc(target),
382                trait_list,
383                method: name.clone(),
384                expr: self_type.at_loc(expression),
385                args: args_with_self,
386                turbofish: turbofish.clone(),
387                prev_generic_list: generic_list.clone(),
388                call_kind: call_kind.clone()
389            };
390
391            requirement.check_or_add(self, ctx)?
392        });
393        Ok(())
394    }
395
396    #[trace_typechecker]
397    #[tracing::instrument(level = "trace", skip_all)]
398    pub fn visit_array_literal(
399        &mut self,
400        expression: &Loc<Expression>,
401        ctx: &Context,
402        generic_list: &GenericListToken,
403    ) -> Result<()> {
404        assuming_kind!(ExprKind::ArrayLiteral(members) = &expression => {
405            for expr in members {
406                self.visit_expression(expr, ctx, generic_list);
407            }
408
409            // unify all elements in array pairwise, e.g. unify(0, 1), unify(1, 2), ...
410            for (l, r) in members.iter().zip(members.iter().skip(1)) {
411                self.unify(r, l, ctx)
412                    .into_diagnostic(r, |diag, Tm{e: expected, g: _got}| {
413                        let expected = expected.display(self);
414                        diag.message(format!(
415                            "Array element type mismatch. Expected {}",
416                            expected
417                        ))
418                        .primary_label(format!("Expected {}", expected))
419                        .secondary_label(members.first().unwrap().loc(), "To match this".to_string())
420                    }, self)?;
421            }
422
423            let inner_type = if members.is_empty() {
424                self.new_generic_type(expression.loc())
425            }
426            else {
427                members[0].get_type(self)
428            };
429
430            let size_type = TypeVar::Known(expression.loc(), KnownType::Integer(members.len().to_bigint()), vec![]).insert(self);
431            let result_type = TypeVar::array(
432                expression.loc(),
433                inner_type,
434                size_type,
435            ).insert(self);
436
437            self.unify_expression_generic_error(expression, &result_type, ctx)?;
438        });
439        Ok(())
440    }
441
442    pub fn visit_array_shorthand_literal(
443        &mut self,
444        expression: &Loc<Expression>,
445        ctx: &Context,
446        generic_list: &GenericListToken,
447    ) -> Result<()> {
448        assuming_kind!(ExprKind::ArrayShorthandLiteral(expr, amount) = &expression => {
449            self.visit_expression(expr, ctx, generic_list);
450
451
452            let inner_type = expr.get_type(self);
453            let size_type = self.visit_const_generic_with_id(amount, generic_list, ConstraintSource::ArraySize, ctx)?;
454            // Force the type to be a uint
455            let uint_type = self.new_generic_tluint(expression.loc());
456            self.unify(&size_type, &uint_type, ctx).into_default_diagnostic(expression.loc(), self)?;
457
458            let result_type = TypeVar::array(expression.loc(), inner_type, size_type).insert(self);
459
460            self.unify_expression_generic_error(expression, &result_type, ctx)?;
461        });
462        Ok(())
463    }
464
465    #[trace_typechecker]
466    #[tracing::instrument(level = "trace", skip_all)]
467    pub fn visit_create_ports(
468        &mut self,
469        expression: &Loc<Expression>,
470        ctx: &Context,
471        _generic_list: &GenericListToken,
472    ) -> Result<()> {
473        assuming_kind!(ExprKind::CreatePorts = &expression => {
474            let inner_type = self.new_generic_type(expression.loc());
475            let inverted = TypeVar::Known(expression.loc(), KnownType::Inverted, vec![inner_type.clone()]).insert(self);
476            let compound = TypeVar::tuple(expression.loc(), vec![inner_type, inverted]).insert(self);
477            self.unify_expression_generic_error(expression, &compound, ctx)?;
478        });
479        Ok(())
480    }
481
482    #[trace_typechecker]
483    #[tracing::instrument(level = "trace", skip_all)]
484    pub fn visit_index(
485        &mut self,
486        expression: &Loc<Expression>,
487        ctx: &Context,
488        generic_list: &GenericListToken,
489    ) -> Result<()> {
490        assuming_kind!(ExprKind::Index(target, index) = &expression => {
491            // Visit child nodes
492            self.visit_expression(target, ctx, generic_list);
493            self.visit_expression(index, ctx, generic_list);
494
495            // Add constraints
496            let inner_type = self.new_generic_type(expression.loc());
497
498            // Unify inner type with this expression
499            self.unify_expression_generic_error(
500                expression,
501                &inner_type,
502                ctx
503            )?;
504
505            let array_size = self.new_generic_tluint(expression.loc());
506            let (int_type, int_size) = self.new_split_generic_uint(index.loc(), ctx.symtab);
507
508            // NOTE[et]: Only used for size constraints of this exact type - this can be a
509            // requirement instead, that way we remove a lot of complexity! :D
510            self.add_constraint(
511                int_size,
512                bits_to_store(ce_var(&array_size) - ce_int(BigInt::one())),
513                index.loc(),
514                &int_type,
515                ConstraintSource::ArrayIndexing
516            );
517
518            self.unify(&index.inner, &int_type, ctx)
519                .into_diagnostic(index.as_ref(), |diag, Tm{e: _expected, g: got}| {
520                    let got = got.display(self);
521                    diag.message(format!("Index must be an integer, got {}", got))
522                        .primary_label("Expected integer".to_string())
523                }, self)?;
524
525            let array_type = TypeVar::array(
526                expression.loc(),
527                expression.get_type(self),
528                array_size.clone()
529            ).insert(self);
530            self.add_requirement(Requirement::ArrayIndexeeIsNonZero {
531                index: index.loc(),
532                array: array_type.clone().at_loc(target),
533                array_size: array_size.clone().at_loc(index)
534            });
535            self.unify(&target.inner, &array_type, ctx)
536                .into_diagnostic(target.as_ref(), |diag, Tm{e: _expected, g: got}| {
537                    let got = got.display(self);
538                    diag
539                        .message(format!("Index target must be an array, got {}", got))
540                        .primary_label("Expected array".to_string())
541                }, self)?;
542        });
543        Ok(())
544    }
545
546    #[trace_typechecker]
547    #[tracing::instrument(level = "trace", skip_all)]
548    pub fn visit_range_index(
549        &mut self,
550        expression: &Loc<Expression>,
551        ctx: &Context,
552        generic_list: &GenericListToken,
553    ) -> Result<()> {
554        assuming_kind!(ExprKind::RangeIndex{
555            target,
556            ref start,
557            ref end,
558        } = &expression => {
559            self.visit_expression(target, ctx, generic_list);
560            // Add constraints
561            let inner_type = self.new_generic_type(target.loc());
562
563            let start_var = self.visit_const_generic_with_id(start, generic_list, ConstraintSource::RangeIndex, ctx)?;
564            let end_var = self.visit_const_generic_with_id(end, generic_list, ConstraintSource::RangeIndex, ctx)?;
565
566            let in_array_size = self.new_generic_tluint(target.loc());
567            let in_array_type = TypeVar::array(expression.loc(), inner_type.clone(), in_array_size.clone()).insert(self);
568            let out_array_size = self.new_generic_tluint(target.loc());
569            let out_array_type = TypeVar::array(expression.loc(), inner_type.clone(), out_array_size.clone()).insert(self);
570
571            let out_size_constraint = ConstraintExpr::Var(end_var.clone()) - ConstraintExpr::Var(start_var.clone());
572            self.add_constraint(out_array_size, out_size_constraint, expression.loc(), &out_array_type, ConstraintSource::RangeIndex);
573
574            self.add_requirement(Requirement::RangeIndexEndAfterStart { expr: expression.loc(), start: start_var.clone().at_loc(&start), end: end_var.clone().at_loc(end) });
575            self.add_requirement(Requirement::RangeIndexInArray { index: end_var.at_loc(end), size: in_array_size.at_loc(&target.loc()) });
576
577            self.unify(&expression.inner, &out_array_type, ctx)
578                .into_default_diagnostic(expression, self)?;
579
580
581            self.unify(&target.inner, &in_array_type, ctx)
582                .into_diagnostic(target.as_ref(), |diag, Tm{e: _expected, g: got}| {
583                    let got = got.display(self);
584                    diag
585                        .message(format!("Index target must be an array, got {}", got))
586                        .primary_label("Expected array".to_string())
587                }, self)?;
588        });
589        Ok(())
590    }
591
592    #[trace_typechecker]
593    #[tracing::instrument(level = "trace", skip_all)]
594    pub fn visit_block_expr(
595        &mut self,
596        expression: &Loc<Expression>,
597        ctx: &Context,
598        generic_list: &GenericListToken,
599    ) -> Result<()> {
600        assuming_kind!(ExprKind::Block(block) = expression => {
601            self.visit_block(block, ctx, generic_list)?;
602
603            if let Some(result) = &block.result {
604                // Unify the return type of the block with the type of this expression
605                self.unify(&expression.inner, &result.inner, ctx)
606                    // NOTE: We could be more specific about this error specifying
607                    // that the type of the block must match the return type, though
608                    // that might just be spammy.
609                    .into_default_diagnostic(result, self)?;
610            } else {
611                // Block without return value. Unify with unit type.
612                expression
613                    .inner
614                    .unify_with(&TypeVar::unit(expression.loc()).insert(self), self)
615                    .commit(self, ctx)
616                    .into_diagnostic(Loc::nowhere(()), |err, Tm{g: _, e: _}| {
617                        diag_anyhow!(
618                            Loc::nowhere(()),
619                            "This error shouldn't be possible: {err:?}"
620                        )}, self)?;
621            }
622        });
623        Ok(())
624    }
625
626    #[trace_typechecker]
627    #[tracing::instrument(level = "trace", skip_all)]
628    pub fn visit_if(
629        &mut self,
630        expression: &Loc<Expression>,
631        ctx: &Context,
632        generic_list: &GenericListToken,
633    ) -> Result<()> {
634        assuming_kind!(ExprKind::If(cond, on_true, on_false) = &expression => {
635            self.visit_expression(cond, ctx, generic_list);
636            self.visit_expression(on_true, ctx, generic_list);
637            self.visit_expression(on_false, ctx, generic_list);
638
639            cond
640                .inner
641                .unify_with(&self.t_bool(cond.loc(), ctx.symtab), self)
642                .commit(self, ctx)
643                .into_diagnostic(cond.as_ref(), |diag, Tm{e: _expected, g: got}| {
644                    let got = got.display(self);
645                    diag.
646                        message(format!("If condition must be a bool, got {}", got))
647                        .primary_label("Expected boolean")
648                }, self)?;
649            self.unify(&on_false.inner, &on_true.inner, ctx)
650                .into_diagnostic(on_false.as_ref(), |diag, tm| {
651                    let (expected, got) = tm.display_e_g(self);
652                    diag.message("If branches have incompatible type")
653                        .primary_label(format!("But this has type {got}"))
654                        .secondary_label(on_true.as_ref(), format!("This branch has type {expected}"))
655                }, self)?;
656            self.unify(expression, &on_false.inner, ctx)
657                .into_default_diagnostic(expression, self)?;
658        });
659        Ok(())
660    }
661
662    #[trace_typechecker]
663    #[tracing::instrument(level = "trace", skip_all)]
664    pub fn visit_match(
665        &mut self,
666        expression: &Loc<Expression>,
667        ctx: &Context,
668        generic_list: &GenericListToken,
669    ) -> Result<()> {
670        assuming_kind!(ExprKind::Match(cond, branches) = &expression => {
671            self.visit_expression(cond, ctx, generic_list);
672
673            for (i, (pattern, result)) in branches.iter().enumerate() {
674                self.visit_pattern(pattern, ctx, generic_list)?;
675
676                self.unify(pattern, &cond.inner, ctx)
677                    .into_default_diagnostic(pattern, self)?;
678
679                self.visit_expression(result, ctx, generic_list);
680
681                if i != 0 {
682                    self.unify(&branches[0].1, result, ctx).into_diagnostic(
683                        result,
684                        |diag, tm| {
685                            let (expected, got) = tm.display_e_g(self);
686                            diag.message("Match branches have incompatible type")
687                                .primary_label(format!("This branch has type {got}"))
688                                .secondary_label(&branches[0].1, format!("But this one has type {expected}"))
689                        }, self
690                    )?;
691                }
692            }
693
694            assert!(
695                !branches.is_empty(),
696                "Empty match statements should be checked by ast lowering"
697            );
698
699            self.unify_expression_generic_error(&branches[0].1, expression, ctx)?;
700        });
701        Ok(())
702    }
703
704    #[trace_typechecker]
705    #[tracing::instrument(level = "trace", skip_all)]
706    pub fn visit_binary_operator(
707        &mut self,
708        expression: &Loc<Expression>,
709        ctx: &Context,
710        generic_list: &GenericListToken,
711    ) -> Result<()> {
712        assuming_kind!(ExprKind::BinaryOperator(lhs, op, rhs) = &expression => {
713            self.visit_expression(lhs, ctx, generic_list);
714            self.visit_expression(rhs, ctx, generic_list);
715            match op.inner {
716                BinaryOperator::Add
717                | BinaryOperator::Sub => {
718                    let (in_t, lhs_size) = self.new_generic_number(expression.loc(), ctx);
719                    let (result_t, result_size) = self.new_generic_number(expression.loc(), ctx);
720
721                    self.add_constraint(
722                        result_size.clone(),
723                        ce_var(&lhs_size) + ce_int(BigInt::one()),
724                        expression.loc(),
725                        &result_t,
726                        ConstraintSource::AdditionOutput
727                    );
728                    self.add_constraint(
729                        lhs_size.clone(),
730                        ce_var(&result_size) + -ce_int(BigInt::one()),
731                        lhs.loc(),
732                        &in_t,
733                        ConstraintSource::AdditionOutput
734                    );
735
736                    self.unify_expression_generic_error(lhs, &in_t, ctx)?;
737                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
738                    self.unify_expression_generic_error(expression, &result_t, ctx)?;
739
740                    self.add_requirement(Requirement::SharedBase(vec![
741                        in_t.at_loc(lhs),
742                        result_t.at_loc(expression)
743                    ]));
744
745                }
746                BinaryOperator::Mul => {
747                    let (lhs_t, lhs_size) = self.new_generic_number(expression.loc(), ctx);
748                    let (rhs_t, rhs_size) = self.new_generic_number(expression.loc(), ctx);
749                    let (result_t, result_size) = self.new_generic_number(expression.loc(), ctx);
750
751                    // Result size is sum of input sizes
752                    self.add_constraint(
753                        result_size.clone(),
754                        ce_var(&lhs_size) + ce_var(&rhs_size),
755                        expression.loc(),
756                        &result_t,
757                        ConstraintSource::MultOutput
758                    );
759                    self.add_constraint(
760                        lhs_size.clone(),
761                        ce_var(&result_size) + -ce_var(&rhs_size),
762                        lhs.loc(),
763                        &lhs_t,
764                        ConstraintSource::MultOutput
765                    );
766                    self.add_constraint(rhs_size.clone(),
767                        ce_var(&result_size) + -ce_var(&lhs_size),
768                        rhs.loc(),
769                        &rhs_t
770                        , ConstraintSource::MultOutput
771                    );
772
773                    self.unify_expression_generic_error(lhs, &lhs_t, ctx)?;
774                    self.unify_expression_generic_error(rhs, &rhs_t, ctx)?;
775                    self.unify_expression_generic_error(expression, &result_t, ctx)?;
776
777                    self.add_requirement(Requirement::SharedBase(vec![
778                        lhs_t.at_loc(lhs),
779                        rhs_t.at_loc(rhs),
780                        result_t.at_loc(expression)
781                    ]));
782                }
783                // Division, being integer division has the same width out as in
784                BinaryOperator::Div | BinaryOperator::Mod => {
785                    let (int_type, _size) = self.new_generic_number(expression.loc(), ctx);
786
787                    self.unify_expression_generic_error(lhs, &int_type, ctx)?;
788                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
789                    self.unify_expression_generic_error(expression, &rhs.inner, ctx)?;
790                },
791                // Shift operators have the same width in as they do out
792                BinaryOperator::LeftShift
793                | BinaryOperator::BitwiseAnd
794                | BinaryOperator::BitwiseXor
795                | BinaryOperator::BitwiseOr
796                | BinaryOperator::ArithmeticRightShift
797                | BinaryOperator::RightShift => {
798                    let (int_type, _size) = self.new_generic_number(expression.loc(), ctx);
799
800                    // FIXME: Make generic over types that can be bitmanipulated
801                    self.unify_expression_generic_error(lhs, &int_type, ctx)?;
802                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
803                    self.unify_expression_generic_error(expression, &rhs.inner, ctx)?;
804                }
805                BinaryOperator::Eq
806                | BinaryOperator::NotEq
807                | BinaryOperator::Gt
808                | BinaryOperator::Lt
809                | BinaryOperator::Ge
810                | BinaryOperator::Le => {
811                    let (base, _size) = self.new_generic_number(expression.loc(), ctx);
812                    // FIXME: Make generic over types that can be compared
813                    self.unify_expression_generic_error(lhs, &base, ctx)?;
814                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
815                    expression
816                        .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
817                        .commit(self, ctx)
818                        .into_default_diagnostic(expression.loc(), self)?;
819                }
820                BinaryOperator::LogicalAnd
821                | BinaryOperator::LogicalOr
822                | BinaryOperator::LogicalXor => {
823                    lhs
824                        .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
825                        .commit(self, ctx)
826                        .into_default_diagnostic(expression.loc(), self)?;
827                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
828
829                    expression
830                        .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
831                        .commit(self, ctx)
832                        .into_default_diagnostic(expression, self)?;
833                }
834            }
835        });
836        Ok(())
837    }
838
839    #[trace_typechecker]
840    #[tracing::instrument(level = "trace", skip_all)]
841    pub fn visit_unary_operator(
842        &mut self,
843        expression: &Loc<Expression>,
844        ctx: &Context,
845        generic_list: &GenericListToken,
846    ) -> Result<()> {
847        assuming_kind!(ExprKind::UnaryOperator(op, operand) = &expression => {
848            self.visit_expression(operand, ctx, generic_list);
849            match &op.inner {
850                UnaryOperator::Sub => {
851                    let int_type = self.new_generic_int(expression.loc(), ctx.symtab).insert(self);
852                    self.unify_expression_generic_error(operand, &int_type, ctx)?;
853                    self.unify_expression_generic_error(expression, &int_type, ctx)?
854                }
855                UnaryOperator::BitwiseNot => {
856                    let (number_type, _) = self.new_generic_number(expression.loc(), ctx);
857                    self.unify_expression_generic_error(operand, &number_type, ctx)?;
858                    self.unify_expression_generic_error(expression, &number_type, ctx)?
859                }
860                UnaryOperator::Not => {
861                    let bool = self.t_bool(expression.loc(), ctx.symtab);
862                    self.unify_expression_generic_error(operand, &bool, ctx)?;
863                    self.unify_expression_generic_error(expression, &bool, ctx)?
864                }
865                UnaryOperator::Dereference => {
866                    let result_type = self.new_generic_type(expression.loc());
867                    let reference_type = TypeVar::wire(expression.loc(), result_type.clone()).insert(self);
868                    self.unify_expression_generic_error(operand, &reference_type, ctx)?;
869                    self.unify_expression_generic_error(expression, &result_type, ctx)?
870                }
871                UnaryOperator::Reference => {
872                    let result_type = self.new_generic_type(expression.loc());
873                    let reference_type = TypeVar::wire(expression.loc(), result_type.clone()).insert(self);
874                    self.unify_expression_generic_error(operand, &result_type, ctx)?;
875                    self.unify_expression_generic_error(expression, &reference_type, ctx)?
876                }
877            }
878        });
879        Ok(())
880    }
881}