spade_typeinference/
expression.rs

1use num::{BigInt, One};
2use spade_common::location_info::{Loc, WithLocation};
3use spade_common::name::Identifier;
4use spade_common::num_ext::InfallibleToBigInt;
5use spade_diagnostics::diagnostic::DiagnosticLevel;
6use spade_diagnostics::{diag_anyhow, Diagnostic};
7use spade_hir::expression::{BinaryOperator, IntLiteralKind, NamedArgument, UnaryOperator};
8use spade_hir::{ExprKind, Expression};
9use spade_macros::trace_typechecker;
10use spade_types::meta_types::MetaType;
11use spade_types::KnownType;
12
13use crate::constraints::{bits_to_store, ce_int, ce_var, ConstraintExpr, ConstraintSource};
14use crate::equation::{TypeVar, TypedExpression};
15use crate::error::{TypeMismatch as Tm, UnificationErrorExt};
16use crate::requirements::{ConstantInt, Requirement};
17use crate::{Context, GenericListToken, HasType, Result, TraceStackEntry, TypeState};
18
/// Destructures `$expr.inner.kind` with `$pattern` and runs `$block` with the
/// pattern's bindings in scope.
///
/// Each `visit_*` method below expects one particular expression kind; if the
/// expression has any other kind this panics instead of producing a diagnostic.
macro_rules! assuming_kind {
    ($pattern:pat = $expr:expr => $block:block) => {
        match &$expr.inner.kind {
            $pattern => $block,
            _ => panic!("Incorrect assumption about expression kind"),
        };
    };
}
28
29impl TypeState {
30    #[trace_typechecker]
31    #[tracing::instrument(level = "trace", skip_all)]
32    pub fn visit_identifier(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
33        assuming_kind!(ExprKind::Identifier(ident) = &expression => {
34            // Add an equation for the anonymous id
35            self.unify_expression_generic_error(
36                expression,
37                &TypedExpression::Name(ident.clone()),
38                ctx
39            )?;
40        });
41        Ok(())
42    }
43
44    #[trace_typechecker]
45    #[tracing::instrument(level = "trace", skip_all)]
46    pub fn visit_type_level_integer(
47        &mut self,
48        expression: &Loc<Expression>,
49        generic_list: &GenericListToken,
50        ctx: &Context,
51    ) -> Result<()> {
52        assuming_kind!(ExprKind::TypeLevelInteger(value) = &expression => {
53            let (t, _size) = self.new_generic_number(expression.loc(), ctx);
54            self.unify(&t, &expression.inner, ctx)
55                .into_diagnostic(expression.loc(), |diag, _tm| {
56                    diag
57                        .level(DiagnosticLevel::Bug)
58                        .message("Failed to unify integer literal with integer")
59                }, self)?;
60            let generic = self
61                .get_generic_list(generic_list)
62                .ok_or_else(|| {
63                    diag_anyhow!(expression, "Found no generic list here")
64                })?
65                .get(value).ok_or_else(|| {
66                Diagnostic::bug(expression, "Found entry for {value} in generic list")
67            })?;
68            self.add_requirement(
69                Requirement::FitsIntLiteral {
70                    value: ConstantInt::Generic(generic.clone()),
71                    target_type: t.at_loc(expression)
72                }
73            )
74        });
75        Ok(())
76    }
77
78    #[tracing::instrument(level = "trace", skip_all)]
79    pub fn visit_pipeline_ref(
80        &mut self,
81        expression: &Loc<Expression>,
82        generic_list: &GenericListToken,
83        ctx: &Context,
84    ) -> Result<()> {
85        assuming_kind!(ExprKind::PipelineRef{stage, name, declares_name, depth_typeexpr_id} = &expression => {
86            // If this reference declares the referenced name, add a new equation
87            if *declares_name {
88                let new_var = self.new_generic_type(expression.loc());
89                self.add_equation(TypedExpression::Name(name.clone().inner), new_var)
90            }
91
92            let depth = self.new_generic_tlint(stage.loc());
93            self.add_equation(TypedExpression::Id(*depth_typeexpr_id), depth.clone());
94            let depth = match &stage.inner {
95                spade_hir::expression::PipelineRefKind::Absolute(name) => {
96                    let key = TypedExpression::Name(name.inner.clone());
97                    let var = if !self.equations.contains_key(&key) {
98                        let var = self.new_generic_tlint(stage.loc());
99                        self.add_equation(key.clone(), var.clone());
100                        self.trace_stack.push(TraceStackEntry::PreAddingPipelineLabel(name.inner.clone(), var.debug_resolve(self)));
101                        var
102                    } else {
103                        let var = self.equations.get(&key).unwrap().clone();
104                        self.trace_stack.push(TraceStackEntry::RecoveringPipelineLabel(name.inner.clone(), var.debug_resolve(self)));
105                        var
106                    };
107                    // NOTE: Safe unwrap, depth is fresh
108                    self.unify(&depth, &var, ctx).unwrap()
109                },
110                spade_hir::expression::PipelineRefKind::Relative(expr) => {
111                    let expr_var = self.hir_type_expr_to_var(expr, generic_list)?;
112                    let total_offset = self.new_generic_tlint(stage.loc());
113                    self.add_constraint(
114                        total_offset.clone(),
115                        ConstraintExpr::Sum(
116                            Box::new(ConstraintExpr::Var(expr_var)),
117                            Box::new(ConstraintExpr::Var(self.get_pipeline_state(expression)?
118                                .current_stage_depth.clone()))
119                        ),
120                        stage.loc(),
121                        &total_offset,
122                        ConstraintSource::PipelineRegOffset{reg: expr.loc(), total: self.get_pipeline_state(expr)?.total_depth.loc()}
123                    );
124                    // Safe unwrap, depth is a fresh type var
125                    self.unify(&depth, &total_offset, ctx).unwrap()
126                },
127            };
128
129            let pipeline_state = self.pipeline_state
130                .as_ref()
131                .ok_or_else(|| diag_anyhow!(
132                    expression,
133                    "Expected a pipeline state"
134                ))?;
135            self.add_requirement(Requirement::ValidPipelineOffset {
136                definition_depth: pipeline_state
137                    .total_depth
138                    .clone(),
139                current_stage: pipeline_state.current_stage_depth.clone().nowhere(),
140                reference_offset: depth.at_loc(stage)
141            });
142
143            // Add an equation for the anonymous id
144            self.unify_expression_generic_error(
145                expression,
146                &TypedExpression::Name(name.clone().inner),
147                ctx
148            )?;
149        });
150        Ok(())
151    }
152
153    #[trace_typechecker]
154    #[tracing::instrument(level = "trace", skip_all)]
155    pub fn visit_int_literal(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
156        assuming_kind!(ExprKind::IntLiteral(value, kind) = &expression => {
157            let (t, _size) = match kind {
158                IntLiteralKind::Unsized => self.new_generic_number(expression.loc(), ctx),
159                IntLiteralKind::Signed(size) => {
160                    let (t, size_var) = self.new_split_generic_int(expression.loc(), ctx.symtab);
161                    // NOTE: Safe unwrap, we're unifying a generic int with a size
162                    size_var
163                        .unify_with(&TypeVar::Known(
164                            expression.loc(),
165                            KnownType::Integer(size.to_bigint()),
166                            vec![]).insert(self),
167                            self
168                        )
169                        .commit(self, ctx)
170                        .unwrap();
171                    (t, size_var)
172                },
173                IntLiteralKind::Unsigned(size) => {
174                    let (t, size_var) = self.new_split_generic_uint(expression.loc(), ctx.symtab);
175                    // NOTE: Safe unwrap, we're unifying a generic int with a size
176                    size_var
177                        .unify_with(&self.new_concrete_int(size.clone(), expression.loc()), self)
178                        .commit(self, ctx)
179                        .unwrap();
180                    (t, size_var)
181                }
182            };
183            self.unify(&t, &expression.inner, ctx)
184                .into_diagnostic(expression.loc(), |diag, Tm{e: _, g: _got}| {
185                    diag
186                        .level(DiagnosticLevel::Bug)
187                        .message("Failed to unify integer literal with integer")
188                }, self)?;
189            self.add_requirement(Requirement::FitsIntLiteral {
190                value: ConstantInt::Literal(value.clone()),
191                target_type: t.at_loc(expression)
192            });
193        });
194        Ok(())
195    }
196
197    #[trace_typechecker]
198    #[tracing::instrument(level = "trace", skip_all)]
199    pub fn visit_bool_literal(
200        &mut self,
201        expression: &Loc<Expression>,
202        ctx: &Context,
203    ) -> Result<()> {
204        assuming_kind!(ExprKind::BoolLiteral(_) = &expression => {
205            expression
206                .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
207                .commit(self, ctx)
208                .into_default_diagnostic(expression, self)?;
209        });
210        Ok(())
211    }
212
213    #[trace_typechecker]
214    #[tracing::instrument(level = "trace", skip_all)]
215    pub fn visit_bit_literal(&mut self, expression: &Loc<Expression>, ctx: &Context) -> Result<()> {
216        assuming_kind!(ExprKind::BitLiteral(_) = &expression => {
217            expression
218                .unify_with(&self.t_bit(expression.loc(), ctx.symtab), self)
219                .commit(self, ctx)
220                .into_default_diagnostic(expression, self)?
221        });
222        Ok(())
223    }
224
225    #[trace_typechecker]
226    #[tracing::instrument(level = "trace", skip_all)]
227    pub fn visit_tuple_literal(
228        &mut self,
229        expression: &Loc<Expression>,
230        ctx: &Context,
231        generic_list: &GenericListToken,
232    ) -> Result<()> {
233        assuming_kind!(ExprKind::TupleLiteral(inner) = &expression => {
234            for expr in inner {
235                self.visit_expression(expr, ctx, generic_list);
236                // NOTE: safe unwrap, we know this expr has a type because we just visited
237            }
238
239            let mut inner_types = vec![];
240            for expr in inner {
241                let t = self.type_of(&TypedExpression::Id(expr.id));
242
243                inner_types.push(t);
244            }
245
246            expression
247                .unify_with(
248                    &TypeVar::Known(expression.loc(), KnownType::Tuple, inner_types).insert(self),
249                    self
250                )
251                .commit(self, ctx)
252                .into_default_diagnostic(expression, self)?
253        });
254        Ok(())
255    }
256
257    #[trace_typechecker]
258    #[tracing::instrument(level = "trace", skip_all)]
259    pub fn visit_tuple_index(
260        &mut self,
261        expression: &Loc<Expression>,
262        ctx: &Context,
263        generic_list: &GenericListToken,
264    ) -> Result<()> {
265        assuming_kind!(ExprKind::TupleIndex(tup, index) = &expression => {
266            self.visit_expression(tup, ctx, generic_list);
267            let t_id = self.type_of(&TypedExpression::Id(tup.id));
268
269            let inner_types = match t_id.resolve(self) {
270                TypeVar::Known(_, KnownType::Tuple, inner) => inner,
271                t @ TypeVar::Known(ref other_source, _, _) => {
272                    return Err(Diagnostic::error(tup.loc(), "Attempt to use tuple indexing on non-tuple")
273                        .primary_label(format!("expected tuple, got {t}", t = t.display(self)))
274                        .secondary_label(index, "Because this is a tuple index")
275                        .secondary_label(other_source, format!("Type {t} inferred here", t = t.display(self)))
276                    );
277                }
278                TypeVar::Unknown(_, _, _, MetaType::Type | MetaType::Any) => {
279                    return Err(
280                        Diagnostic::error(tup.as_ref(), "Type of tuple indexee must be known at this point")
281                            .primary_label("The type of this must be known")
282                    )
283                }
284                TypeVar::Unknown(ref other_source, _, _, meta @ (MetaType::Uint | MetaType::Int | MetaType::Number | MetaType::Bool)) => {
285                    return Err(
286                        Diagnostic::error(tup.as_ref(), "Cannot use tuple indexing on a type level number")
287                            .primary_label("Tuple indexing on type level number")
288                        .secondary_label(other_source, format!("Meta-type {meta} inferred here"))
289                    )
290                }
291            };
292
293            if (index.inner as usize) < inner_types.len() {
294                let true_inner_type = inner_types[index.inner as usize].clone();
295                self.unify_expression_generic_error(
296                    expression,
297                    &true_inner_type,
298                    ctx
299                )?
300            } else {
301                return Err(Diagnostic::error(index, "Tuple index out of bounds")
302                    .primary_label(format!("Tuple only has {} elements", inner_types.len()))
303                    .note(format!("     Index: {}", index))
304                    .note(format!("Tuple size: {}", inner_types.len()))
305                );
306            }
307        });
308        Ok(())
309    }
310
311    #[trace_typechecker]
312    #[tracing::instrument(level = "trace", skip_all)]
313    pub fn visit_field_access(
314        &mut self,
315        expression: &Loc<Expression>,
316        ctx: &Context,
317        generic_list: &GenericListToken,
318    ) -> Result<()> {
319        assuming_kind!(ExprKind::FieldAccess(target, field) = &expression => {
320            self.visit_expression(target, ctx, generic_list);
321
322            let target_type = self.type_of(&TypedExpression::Id(target.id));
323            let self_type = self.type_of(&TypedExpression::Id(expression.id));
324
325            let requirement = Requirement::HasField {
326                target_type: target_type.at_loc(target),
327                field: field.clone(),
328                expr: self_type.at_loc(expression)
329            };
330
331            requirement.check_or_add(self, ctx)?;
332        });
333        Ok(())
334    }
335
336    #[trace_typechecker]
337    #[tracing::instrument(level = "trace", skip_all)]
338    pub fn visit_method_call(
339        &mut self,
340        expression: &Loc<Expression>,
341        ctx: &Context,
342        generic_list: &GenericListToken,
343    ) -> Result<()> {
344        assuming_kind!(ExprKind::MethodCall{call_kind, target, name, args, turbofish} = &expression => {
345            // NOTE: We don't visit_expression here as it is being added to the argument_list
346            // which we *do* visit
347            // self.visit_expression(target, ctx, generic_list)?;
348
349            let args_with_self = args.clone().map(|mut args| {
350                match &mut args {
351                    spade_hir::ArgumentList::Named(inner) => {
352                        inner.push(NamedArgument::Full(
353                            Identifier("self".to_string()).at_loc(target),
354                            target.as_ref().clone()
355                        ))
356                    },
357                    spade_hir::ArgumentList::Positional(list) => list.insert(0, target.as_ref().clone()),
358                };
359                args
360            });
361
362            self.visit_argument_list(&args_with_self, ctx, generic_list)?;
363
364            let target_type = self.type_of(&TypedExpression::Id(target.id));
365            let self_type = self.type_of(&TypedExpression::Id(expression.id));
366
367            let trait_list = if let TypeVar::Unknown(_, _, trait_list, MetaType::Type) = &target_type.resolve(self) {
368                if !trait_list.inner.is_empty() {
369                    Some(trait_list.clone())
370                } else {
371                    None
372                }
373            } else {
374                None
375            };
376
377            let requirement = Requirement::HasMethod {
378                expr_id: expression.map_ref(|e| e.id),
379                target_type: target_type.at_loc(target),
380                trait_list,
381                method: name.clone(),
382                expr: self_type.at_loc(expression),
383                args: args_with_self,
384                turbofish: turbofish.clone(),
385                prev_generic_list: generic_list.clone(),
386                call_kind: call_kind.clone()
387            };
388
389            requirement.check_or_add(self, ctx)?
390        });
391        Ok(())
392    }
393
394    #[trace_typechecker]
395    #[tracing::instrument(level = "trace", skip_all)]
396    pub fn visit_array_literal(
397        &mut self,
398        expression: &Loc<Expression>,
399        ctx: &Context,
400        generic_list: &GenericListToken,
401    ) -> Result<()> {
402        assuming_kind!(ExprKind::ArrayLiteral(members) = &expression => {
403            for expr in members {
404                self.visit_expression(expr, ctx, generic_list);
405            }
406
407            // unify all elements in array pairwise, e.g. unify(0, 1), unify(1, 2), ...
408            for (l, r) in members.iter().zip(members.iter().skip(1)) {
409                self.unify(r, l, ctx)
410                    .into_diagnostic(r, |diag, Tm{e: expected, g: _got}| {
411                        let expected = expected.display(self);
412                        diag.message(format!(
413                            "Array element type mismatch. Expected {}",
414                            expected
415                        ))
416                        .primary_label(format!("Expected {}", expected))
417                        .secondary_label(members.first().unwrap().loc(), "To match this".to_string())
418                    }, self)?;
419            }
420
421            let inner_type = if members.is_empty() {
422                self.new_generic_type(expression.loc())
423            }
424            else {
425                members[0].get_type(self)
426            };
427
428            let size_type = TypeVar::Known(expression.loc(), KnownType::Integer(members.len().to_bigint()), vec![]).insert(self);
429            let result_type = TypeVar::array(
430                expression.loc(),
431                inner_type,
432                size_type,
433            ).insert(self);
434
435            self.unify_expression_generic_error(expression, &result_type, ctx)?;
436        });
437        Ok(())
438    }
439
440    pub fn visit_array_shorthand_literal(
441        &mut self,
442        expression: &Loc<Expression>,
443        ctx: &Context,
444        generic_list: &GenericListToken,
445    ) -> Result<()> {
446        assuming_kind!(ExprKind::ArrayShorthandLiteral(expr, amount) = &expression => {
447            self.visit_expression(expr, ctx, generic_list);
448
449
450            let inner_type = expr.get_type(self);
451            let size_type = self.visit_const_generic_with_id(amount, generic_list, ConstraintSource::ArraySize, ctx)?;
452            // Force the type to be a uint
453            let uint_type = self.new_generic_tluint(expression.loc());
454            self.unify(&size_type, &uint_type, ctx).into_default_diagnostic(expression.loc(), self)?;
455
456            let result_type = TypeVar::array(expression.loc(), inner_type, size_type).insert(self);
457
458            self.unify_expression_generic_error(expression, &result_type, ctx)?;
459        });
460        Ok(())
461    }
462
463    #[trace_typechecker]
464    #[tracing::instrument(level = "trace", skip_all)]
465    pub fn visit_create_ports(
466        &mut self,
467        expression: &Loc<Expression>,
468        ctx: &Context,
469        _generic_list: &GenericListToken,
470    ) -> Result<()> {
471        assuming_kind!(ExprKind::CreatePorts = &expression => {
472            let inner_type = self.new_generic_type(expression.loc());
473            let inverted = TypeVar::Known(expression.loc(), KnownType::Inverted, vec![inner_type.clone()]).insert(self);
474            let compound = TypeVar::tuple(expression.loc(), vec![inner_type, inverted]).insert(self);
475            self.unify_expression_generic_error(expression, &compound, ctx)?;
476        });
477        Ok(())
478    }
479
480    #[trace_typechecker]
481    #[tracing::instrument(level = "trace", skip_all)]
482    pub fn visit_index(
483        &mut self,
484        expression: &Loc<Expression>,
485        ctx: &Context,
486        generic_list: &GenericListToken,
487    ) -> Result<()> {
488        assuming_kind!(ExprKind::Index(target, index) = &expression => {
489            // Visit child nodes
490            self.visit_expression(target, ctx, generic_list);
491            self.visit_expression(index, ctx, generic_list);
492
493            // Add constraints
494            let inner_type = self.new_generic_type(expression.loc());
495
496            // Unify inner type with this expression
497            self.unify_expression_generic_error(
498                expression,
499                &inner_type,
500                ctx
501            )?;
502
503            let array_size = self.new_generic_tluint(expression.loc());
504            let (int_type, int_size) = self.new_split_generic_uint(index.loc(), ctx.symtab);
505
506            // NOTE[et]: Only used for size constraints of this exact type - this can be a
507            // requirement instead, that way we remove a lot of complexity! :D
508            self.add_constraint(
509                int_size,
510                bits_to_store(ce_var(&array_size) - ce_int(BigInt::one())),
511                index.loc(),
512                &int_type,
513                ConstraintSource::ArrayIndexing
514            );
515
516            self.unify(&index.inner, &int_type, ctx)
517                .into_diagnostic(index.as_ref(), |diag, Tm{e: _expected, g: got}| {
518                    let got = got.display(self);
519                    diag.message(format!("Index must be an integer, got {}", got))
520                        .primary_label("Expected integer".to_string())
521                }, self)?;
522
523            let array_type = TypeVar::array(
524                expression.loc(),
525                expression.get_type(self),
526                array_size.clone()
527            ).insert(self);
528            self.add_requirement(Requirement::ArrayIndexeeIsNonZero {
529                index: index.loc(),
530                array: array_type.clone().at_loc(target),
531                array_size: array_size.clone().at_loc(index)
532            });
533            self.unify(&target.inner, &array_type, ctx)
534                .into_diagnostic(target.as_ref(), |diag, Tm{e: _expected, g: got}| {
535                    let got = got.display(self);
536                    diag
537                        .message(format!("Index target must be an array, got {}", got))
538                        .primary_label("Expected array".to_string())
539                }, self)?;
540        });
541        Ok(())
542    }
543
544    #[trace_typechecker]
545    #[tracing::instrument(level = "trace", skip_all)]
546    pub fn visit_range_index(
547        &mut self,
548        expression: &Loc<Expression>,
549        ctx: &Context,
550        generic_list: &GenericListToken,
551    ) -> Result<()> {
552        assuming_kind!(ExprKind::RangeIndex{
553            target,
554            ref start,
555            ref end,
556        } = &expression => {
557            self.visit_expression(target, ctx, generic_list);
558            // Add constraints
559            let inner_type = self.new_generic_type(target.loc());
560
561            let start_var = self.visit_const_generic_with_id(start, generic_list, ConstraintSource::RangeIndex, ctx)?;
562            let end_var = self.visit_const_generic_with_id(end, generic_list, ConstraintSource::RangeIndex, ctx)?;
563
564            let in_array_size = self.new_generic_tluint(target.loc());
565            let in_array_type = TypeVar::array(expression.loc(), inner_type.clone(), in_array_size.clone()).insert(self);
566            let out_array_size = self.new_generic_tluint(target.loc());
567            let out_array_type = TypeVar::array(expression.loc(), inner_type.clone(), out_array_size.clone()).insert(self);
568
569            let out_size_constraint = ConstraintExpr::Var(end_var.clone()) - ConstraintExpr::Var(start_var.clone());
570            self.add_constraint(out_array_size, out_size_constraint, expression.loc(), &out_array_type, ConstraintSource::RangeIndex);
571
572            self.add_requirement(Requirement::RangeIndexEndAfterStart { expr: expression.loc(), start: start_var.clone().at_loc(&start), end: end_var.clone().at_loc(end) });
573            self.add_requirement(Requirement::RangeIndexInArray { index: end_var.at_loc(end), size: in_array_size.at_loc(&target.loc()) });
574
575            self.unify(&expression.inner, &out_array_type, ctx)
576                .into_default_diagnostic(expression, self)?;
577
578
579            self.unify(&target.inner, &in_array_type, ctx)
580                .into_diagnostic(target.as_ref(), |diag, Tm{e: _expected, g: got}| {
581                    let got = got.display(self);
582                    diag
583                        .message(format!("Index target must be an array, got {}", got))
584                        .primary_label("Expected array".to_string())
585                }, self)?;
586        });
587        Ok(())
588    }
589
590    #[trace_typechecker]
591    #[tracing::instrument(level = "trace", skip_all)]
592    pub fn visit_block_expr(
593        &mut self,
594        expression: &Loc<Expression>,
595        ctx: &Context,
596        generic_list: &GenericListToken,
597    ) -> Result<()> {
598        assuming_kind!(ExprKind::Block(block) = expression => {
599            self.visit_block(block, ctx, generic_list)?;
600
601            if let Some(result) = &block.result {
602                // Unify the return type of the block with the type of this expression
603                self.unify(&expression.inner, &result.inner, ctx)
604                    // NOTE: We could be more specific about this error specifying
605                    // that the type of the block must match the return type, though
606                    // that might just be spammy.
607                    .into_default_diagnostic(result, self)?;
608            } else {
609                // Block without return value. Unify with unit type.
610                expression
611                    .inner
612                    .unify_with(&TypeVar::unit(expression.loc()).insert(self), self)
613                    .commit(self, ctx)
614                    .into_diagnostic(Loc::nowhere(()), |err, Tm{g: _, e: _}| {
615                        diag_anyhow!(
616                            Loc::nowhere(()),
617                            "This error shouldn't be possible: {err:?}"
618                        )}, self)?;
619            }
620        });
621        Ok(())
622    }
623
624    #[trace_typechecker]
625    #[tracing::instrument(level = "trace", skip_all)]
626    pub fn visit_if(
627        &mut self,
628        expression: &Loc<Expression>,
629        ctx: &Context,
630        generic_list: &GenericListToken,
631    ) -> Result<()> {
632        assuming_kind!(ExprKind::If(cond, on_true, on_false) = &expression => {
633            self.visit_expression(cond, ctx, generic_list);
634            self.visit_expression(on_true, ctx, generic_list);
635            self.visit_expression(on_false, ctx, generic_list);
636
637            cond
638                .inner
639                .unify_with(&self.t_bool(cond.loc(), ctx.symtab), self)
640                .commit(self, ctx)
641                .into_diagnostic(cond.as_ref(), |diag, Tm{e: _expected, g: got}| {
642                    let got = got.display(self);
643                    diag.
644                        message(format!("If condition must be a bool, got {}", got))
645                        .primary_label("Expected boolean")
646                }, self)?;
647            self.unify(&on_false.inner, &on_true.inner, ctx)
648                .into_diagnostic(on_false.as_ref(), |diag, tm| {
649                    let (expected, got) = tm.display_e_g(self);
650                    diag.message("If branches have incompatible type")
651                        .primary_label(format!("But this has type {got}"))
652                        .secondary_label(on_true.as_ref(), format!("This branch has type {expected}"))
653                }, self)?;
654            self.unify(expression, &on_false.inner, ctx)
655                .into_default_diagnostic(expression, self)?;
656        });
657        Ok(())
658    }
659
660    #[trace_typechecker]
661    #[tracing::instrument(level = "trace", skip_all)]
662    pub fn visit_match(
663        &mut self,
664        expression: &Loc<Expression>,
665        ctx: &Context,
666        generic_list: &GenericListToken,
667    ) -> Result<()> {
668        assuming_kind!(ExprKind::Match(cond, branches) = &expression => {
669            self.visit_expression(cond, ctx, generic_list);
670
671            for (i, (pattern, result)) in branches.iter().enumerate() {
672                self.visit_pattern(pattern, ctx, generic_list)?;
673
674                self.unify(pattern, &cond.inner, ctx)
675                    .into_default_diagnostic(pattern, self)?;
676
677                self.visit_expression(result, ctx, generic_list);
678
679                if i != 0 {
680                    self.unify(&branches[0].1, result, ctx).into_diagnostic(
681                        result,
682                        |diag, tm| {
683                            let (expected, got) = tm.display_e_g(self);
684                            diag.message("Match branches have incompatible type")
685                                .primary_label(format!("This branch has type {got}"))
686                                .secondary_label(&branches[0].1, format!("But this one has type {expected}"))
687                        }, self
688                    )?;
689                }
690            }
691
692            assert!(
693                !branches.is_empty(),
694                "Empty match statements should be checked by ast lowering"
695            );
696
697            self.unify_expression_generic_error(&branches[0].1, expression, ctx)?;
698        });
699        Ok(())
700    }
701
702    #[trace_typechecker]
703    #[tracing::instrument(level = "trace", skip_all)]
704    pub fn visit_binary_operator(
705        &mut self,
706        expression: &Loc<Expression>,
707        ctx: &Context,
708        generic_list: &GenericListToken,
709    ) -> Result<()> {
710        assuming_kind!(ExprKind::BinaryOperator(lhs, op, rhs) = &expression => {
711            self.visit_expression(lhs, ctx, generic_list);
712            self.visit_expression(rhs, ctx, generic_list);
713            match op.inner {
714                BinaryOperator::Add
715                | BinaryOperator::Sub => {
716                    let (in_t, lhs_size) = self.new_generic_number(expression.loc(), ctx);
717                    let (result_t, result_size) = self.new_generic_number(expression.loc(), ctx);
718
719                    self.add_constraint(
720                        result_size.clone(),
721                        ce_var(&lhs_size) + ce_int(BigInt::one()),
722                        expression.loc(),
723                        &result_t,
724                        ConstraintSource::AdditionOutput
725                    );
726                    self.add_constraint(
727                        lhs_size.clone(),
728                        ce_var(&result_size) + -ce_int(BigInt::one()),
729                        lhs.loc(),
730                        &in_t,
731                        ConstraintSource::AdditionOutput
732                    );
733
734                    self.unify_expression_generic_error(lhs, &in_t, ctx)?;
735                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
736                    self.unify_expression_generic_error(expression, &result_t, ctx)?;
737
738                    self.add_requirement(Requirement::SharedBase(vec![
739                        in_t.at_loc(lhs),
740                        result_t.at_loc(expression)
741                    ]));
742
743                }
744                BinaryOperator::Mul => {
745                    let (lhs_t, lhs_size) = self.new_generic_number(expression.loc(), ctx);
746                    let (rhs_t, rhs_size) = self.new_generic_number(expression.loc(), ctx);
747                    let (result_t, result_size) = self.new_generic_number(expression.loc(), ctx);
748
749                    // Result size is sum of input sizes
750                    self.add_constraint(
751                        result_size.clone(),
752                        ce_var(&lhs_size) + ce_var(&rhs_size),
753                        expression.loc(),
754                        &result_t,
755                        ConstraintSource::MultOutput
756                    );
757                    self.add_constraint(
758                        lhs_size.clone(),
759                        ce_var(&result_size) + -ce_var(&rhs_size),
760                        lhs.loc(),
761                        &lhs_t,
762                        ConstraintSource::MultOutput
763                    );
764                    self.add_constraint(rhs_size.clone(),
765                        ce_var(&result_size) + -ce_var(&lhs_size),
766                        rhs.loc(),
767                        &rhs_t
768                        , ConstraintSource::MultOutput
769                    );
770
771                    self.unify_expression_generic_error(lhs, &lhs_t, ctx)?;
772                    self.unify_expression_generic_error(rhs, &rhs_t, ctx)?;
773                    self.unify_expression_generic_error(expression, &result_t, ctx)?;
774
775                    self.add_requirement(Requirement::SharedBase(vec![
776                        lhs_t.at_loc(lhs),
777                        rhs_t.at_loc(rhs),
778                        result_t.at_loc(expression)
779                    ]));
780                }
781                // Division, being integer division has the same width out as in
782                BinaryOperator::Div | BinaryOperator::Mod => {
783                    let (int_type, _size) = self.new_generic_number(expression.loc(), ctx);
784
785                    self.unify_expression_generic_error(lhs, &int_type, ctx)?;
786                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
787                    self.unify_expression_generic_error(expression, &rhs.inner, ctx)?;
788                },
789                // Shift operators have the same width in as they do out
790                BinaryOperator::LeftShift
791                | BinaryOperator::BitwiseAnd
792                | BinaryOperator::BitwiseXor
793                | BinaryOperator::BitwiseOr
794                | BinaryOperator::ArithmeticRightShift
795                | BinaryOperator::RightShift => {
796                    let (int_type, _size) = self.new_generic_number(expression.loc(), ctx);
797
798                    // FIXME: Make generic over types that can be bitmanipulated
799                    self.unify_expression_generic_error(lhs, &int_type, ctx)?;
800                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
801                    self.unify_expression_generic_error(expression, &rhs.inner, ctx)?;
802                }
803                BinaryOperator::Eq
804                | BinaryOperator::NotEq
805                | BinaryOperator::Gt
806                | BinaryOperator::Lt
807                | BinaryOperator::Ge
808                | BinaryOperator::Le => {
809                    let (base, _size) = self.new_generic_number(expression.loc(), ctx);
810                    // FIXME: Make generic over types that can be compared
811                    self.unify_expression_generic_error(lhs, &base, ctx)?;
812                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
813                    expression
814                        .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
815                        .commit(self, ctx)
816                        .into_default_diagnostic(expression.loc(), self)?;
817                }
818                BinaryOperator::LogicalAnd
819                | BinaryOperator::LogicalOr
820                | BinaryOperator::LogicalXor => {
821                    lhs
822                        .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
823                        .commit(self, ctx)
824                        .into_default_diagnostic(expression.loc(), self)?;
825                    self.unify_expression_generic_error(lhs, &rhs.inner, ctx)?;
826
827                    expression
828                        .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
829                        .commit(self, ctx)
830                        .into_default_diagnostic(expression, self)?;
831                }
832            }
833        });
834        Ok(())
835    }
836
837    #[trace_typechecker]
838    #[tracing::instrument(level = "trace", skip_all)]
839    pub fn visit_unary_operator(
840        &mut self,
841        expression: &Loc<Expression>,
842        ctx: &Context,
843        generic_list: &GenericListToken,
844    ) -> Result<()> {
845        assuming_kind!(ExprKind::UnaryOperator(op, operand) = &expression => {
846            self.visit_expression(operand, ctx, generic_list);
847            match &op.inner {
848                UnaryOperator::Sub => {
849                    let int_type = self.new_generic_int(expression.loc(), ctx.symtab).insert(self);
850                    self.unify_expression_generic_error(operand, &int_type, ctx)?;
851                    self.unify_expression_generic_error(expression, &int_type, ctx)?
852                }
853                UnaryOperator::BitwiseNot => {
854                    let (number_type, _) = self.new_generic_number(expression.loc(), ctx);
855                    self.unify_expression_generic_error(operand, &number_type, ctx)?;
856                    self.unify_expression_generic_error(expression, &number_type, ctx)?
857                }
858                UnaryOperator::Not => {
859                    let bool = self.t_bool(expression.loc(), ctx.symtab);
860                    self.unify_expression_generic_error(operand, &bool, ctx)?;
861                    self.unify_expression_generic_error(expression, &bool, ctx)?
862                }
863                UnaryOperator::Dereference => {
864                    let result_type = self.new_generic_type(expression.loc());
865                    let reference_type = TypeVar::wire(expression.loc(), result_type.clone()).insert(self);
866                    self.unify_expression_generic_error(operand, &reference_type, ctx)?;
867                    self.unify_expression_generic_error(expression, &result_type, ctx)?
868                }
869                UnaryOperator::Reference => {
870                    let result_type = self.new_generic_type(expression.loc());
871                    let reference_type = TypeVar::wire(expression.loc(), result_type.clone()).insert(self);
872                    self.unify_expression_generic_error(operand, &result_type, ctx)?;
873                    self.unify_expression_generic_error(expression, &reference_type, ctx)?
874                }
875            }
876        });
877        Ok(())
878    }
879}