spade_typeinference/
lib.rs

// This algorithm is based off the excellent lecture here
// https://www.youtube.com/watch?v=xJXcZp2vgLs
//
// The basic idea is to go through every node in the HIR tree, for every typed thing,
// add an equation indicating a constraint on that thing. This can only be done once
// and should be done by the visitor for that node. The visitor should then unify
// types according to the rules of the node.
8
9use std::cmp::PartialEq;
10use std::collections::{BTreeSet, HashMap};
11use std::sync::Arc;
12
13use colored::Colorize;
14use hir::expression::CallKind;
15use hir::{
16    param_util, Binding, ConstGeneric, Parameter, PipelineRegMarkerExtra, TypeExpression, TypeSpec,
17    UnitHead, UnitKind, WalTrace, WhereClause,
18};
19use itertools::Itertools;
20use method_resolution::{FunctionLikeName, IntoImplTarget};
21use num::{BigInt, Zero};
22use serde::{Deserialize, Serialize};
23use spade_common::id_tracker::{ExprID, ImplID};
24use spade_common::num_ext::InfallibleToBigInt;
25use spade_diagnostics::{diag_anyhow, diag_bail, Diagnostic};
26use spade_macros::trace_typechecker;
27use spade_types::meta_types::{unify_meta, MetaType};
28use trace_stack::TraceStack;
29use tracing::{info, trace};
30
31use spade_common::location_info::{Loc, WithLocation};
32use spade_common::name::{Identifier, NameID, Path};
33use spade_hir::param_util::{match_args_with_params, Argument};
34use spade_hir::symbol_table::{Patternable, PatternableKind, SymbolTable, TypeSymbol};
35use spade_hir::{self as hir, ConstGenericWithId, ImplTarget};
36use spade_hir::{
37    ArgumentList, Block, ExprKind, Expression, ItemList, Pattern, PatternArgument, Register,
38    Statement, TraitName, TraitSpec, TypeParam, Unit,
39};
40use spade_types::KnownType;
41
42use constraints::{
43    bits_to_store, ce_int, ce_var, ConstraintExpr, ConstraintRhs, ConstraintSource, TypeConstraints,
44};
45use equation::{TraitList, TraitReq, TypeEquations, TypeVar, TypedExpression};
46use error::{
47    error_pattern_type_mismatch, Result, UnificationError, UnificationErrorExt, UnificationTrace,
48};
49use fixed_types::{t_bool, t_clock, t_int, t_uint};
50use requirements::{Replacement, Requirement};
51use trace_stack::{format_trace_stack, TraceStackEntry};
52
53use crate::error::TypeMismatch as Tm;
54use crate::requirements::ConstantInt;
55use crate::traits::{TraitImpl, TraitImplList};
56
57mod constraints;
58pub mod dump;
59pub mod equation;
60pub mod error;
61pub mod expression;
62pub mod fixed_types;
63pub mod method_resolution;
64pub mod mir_type_lowering;
65mod requirements;
66pub mod testutil;
67pub mod trace_stack;
68pub mod traits;
69
70pub struct Context<'a> {
71    pub symtab: &'a SymbolTable,
72    pub items: &'a ItemList,
73    pub trait_impls: &'a TraitImplList,
74}
75
// NOTE(allow) This is a debug macro which is not normally used but can come in handy
//
// Pushes a formatted message onto `$self.trace_stack`, e.g.
// `add_trace!(self, "unifying {a:?} with {b:?}")`.
//
// The trace stack stores `TraceStackEntry` values (`TraceStack` is the container type
// itself), so the message must be wrapped in `TraceStackEntry::Message`. Because the macro
// is normally unused, its body is only type checked on expansion; the previous
// `TraceStack::Message` would have failed to compile at the first use site.
#[allow(unused_macros)]
macro_rules! add_trace {
    ($self:expr, $($arg : tt) *) => {
        $self.trace_stack.push(TraceStackEntry::Message(format!($($arg)*)))
    }
}
83
84#[derive(Debug)]
85pub enum GenericListSource<'a> {
86    /// For when you just need a new generic list but have no need to refer back
87    /// to it in the future
88    Anonymous,
89    /// For managing the generics of callable with the specified name when type checking
90    /// their body
91    Definition(&'a NameID),
92    ImplBlock {
93        target: &'a ImplTarget,
94        id: ImplID,
95    },
96    /// For expressions which instantiate generic items
97    Expression(ExprID),
98}
99
100/// Stored version of GenericListSource
101#[derive(Clone, Hash, Eq, PartialEq, Debug)]
102pub enum GenericListToken {
103    Anonymous(usize),
104    Definition(NameID),
105    ImplBlock(ImplTarget, ImplID),
106    Expression(ExprID),
107}
108
109pub struct TurbofishCtx<'a> {
110    turbofish: &'a Loc<ArgumentList<TypeExpression>>,
111    prev_generic_list: &'a GenericListToken,
112    type_ctx: &'a Context<'a>,
113}
114
115#[derive(Clone)]
116pub struct PipelineState {
117    current_stage_depth: TypeVar,
118    total_depth: Loc<TypeVar>,
119    pipeline_loc: Loc<()>,
120}
121
122/// State of the type inference algorithm
123#[derive(Clone)]
124pub struct TypeState {
125    equations: TypeEquations,
126    next_typeid: u64,
127    // List of the mapping between generic parameters and type vars.
128    // The key is the index of the expression for which this generic list is associated. (if this
129    // is a generic list for a call whose expression id is x to f<A, B>, then generic_lists[x] will
130    // be {A: <type var>, b: <type var>}
131    // Managed here because unification must update *all* TypeVars in existence.
132    generic_lists: HashMap<GenericListToken, HashMap<NameID, TypeVar>>,
133
134    constraints: TypeConstraints,
135
136    /// Requirements which must be fulfilled but which do not guide further type inference.
137    /// For example, if seeing `let y = x.a` before knowing the type of `x`, a requirement is
138    /// added to say "x has field a, and y should be the type of that field"
139    requirements: Vec<Requirement>,
140
141    replacements: HashMap<TypeVar, TypeVar>,
142
143    /// The type var containing the depth of the pipeline stage we are currently in
144    pipeline_state: Option<PipelineState>,
145
146    pub trait_impls: TraitImplList,
147
148    pub trace_stack: Arc<TraceStack>,
149
150    /// (Experimental) Use Affine- or Interval-Arithmetic to bounds check integers in a separate
151    /// module.
152    pub use_wordlenght_inference: bool,
153}
154
155impl Default for TypeState {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl TypeState {
162    pub fn new() -> Self {
163        Self {
164            equations: HashMap::new(),
165            next_typeid: 0,
166            trace_stack: Arc::new(TraceStack::new()),
167            constraints: TypeConstraints::new(),
168            requirements: vec![],
169            replacements: HashMap::new(),
170            generic_lists: HashMap::new(),
171            trait_impls: TraitImplList::new(),
172            pipeline_state: None,
173            use_wordlenght_inference: false,
174        }
175    }
176
177    pub fn set_wordlength_inferece(self, use_wordlenght_inference: bool) -> Self {
178        Self {
179            use_wordlenght_inference,
180            ..self
181        }
182    }
183
184    pub fn get_equations(&self) -> &TypeEquations {
185        &self.equations
186    }
187
188    pub fn get_constraints(&self) -> &TypeConstraints {
189        &self.constraints
190    }
191
192    // Get a generic list with a safe unwrap since a token is acquired
193    pub fn get_generic_list<'a>(
194        &'a self,
195        generic_list_token: &'a GenericListToken,
196    ) -> &'a HashMap<NameID, TypeVar> {
197        &self.generic_lists[generic_list_token]
198    }
199
200    #[tracing::instrument(level = "trace", skip_all)]
201    fn hir_type_expr_to_var<'a>(
202        &'a mut self,
203        e: &Loc<hir::TypeExpression>,
204        generic_list_token: &GenericListToken,
205    ) -> Result<TypeVar> {
206        let tvar = match &e.inner {
207            hir::TypeExpression::Integer(i) => {
208                TypeVar::Known(e.loc(), KnownType::Integer(i.clone()), vec![])
209            }
210            hir::TypeExpression::TypeSpec(spec) => {
211                self.type_var_from_hir(e.loc(), &spec.clone(), generic_list_token)?
212            }
213            hir::TypeExpression::ConstGeneric(g) => {
214                let constraint = self.visit_const_generic(g, generic_list_token)?;
215
216                let tvar = self.new_generic_tlnumber(e.loc());
217                self.add_constraint(
218                    tvar.clone(),
219                    constraint,
220                    g.loc(),
221                    &tvar,
222                    ConstraintSource::Where,
223                );
224
225                tvar
226            }
227        };
228
229        Ok(tvar)
230    }
231
232    #[tracing::instrument(level = "trace", skip_all, fields(%hir_type))]
233    pub fn type_var_from_hir<'a>(
234        &'a mut self,
235        loc: Loc<()>,
236        hir_type: &crate::hir::TypeSpec,
237        generic_list_token: &GenericListToken,
238    ) -> Result<TypeVar> {
239        let generic_list = self.get_generic_list(generic_list_token);
240        match &hir_type {
241            hir::TypeSpec::Declared(base, params) => {
242                let params = params
243                    .iter()
244                    .map(|e| self.hir_type_expr_to_var(e, generic_list_token))
245                    .collect::<Result<Vec<_>>>()?;
246
247                Ok(TypeVar::Known(
248                    loc,
249                    KnownType::Named(base.inner.clone()),
250                    params,
251                ))
252            }
253            hir::TypeSpec::Generic(name) => match generic_list.get(&name.inner) {
254                Some(t) => Ok(t.clone()),
255                None => {
256                    for list_source in self.generic_lists.keys() {
257                        info!("Generic lists exist for {list_source:?}");
258                    }
259                    info!("Current source is {generic_list_token:?}");
260                    panic!("No entry in generic list for {name}");
261                }
262            },
263            hir::TypeSpec::Tuple(inner) => {
264                let inner = inner
265                    .iter()
266                    .map(|t| self.type_var_from_hir(loc, t, generic_list_token))
267                    .collect::<Result<_>>()?;
268                Ok(TypeVar::tuple(loc, inner))
269            }
270            hir::TypeSpec::Array { inner, size } => {
271                let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
272                let size = self.hir_type_expr_to_var(size, generic_list_token)?;
273
274                Ok(TypeVar::array(loc, inner, size))
275            }
276            hir::TypeSpec::Wire(inner) => Ok(TypeVar::wire(
277                loc,
278                self.type_var_from_hir(loc, inner, generic_list_token)?,
279            )),
280            hir::TypeSpec::Inverted(inner) => Ok(TypeVar::inverted(
281                loc,
282                self.type_var_from_hir(loc, inner, generic_list_token)?,
283            )),
284            hir::TypeSpec::Wildcard(_) => Ok(self.new_generic_any()),
285            hir::TypeSpec::TraitSelf(_) => {
286                panic!("Trying to convert TraitSelf to type inference type var")
287            }
288        }
289    }
290
291    /// Returns the type of the expression with the specified id. Error if no equation
292    /// for the specified expression exists
293    pub fn type_of(&self, expr: &TypedExpression) -> Result<TypeVar> {
294        for (e, t) in &self.equations {
295            if e == expr {
296                return Ok(t.clone());
297            }
298        }
299        panic!("Tried looking up the type of {expr:?} but it was not found")
300    }
301
302    pub fn new_generic_int(&mut self, loc: Loc<()>, symtab: &SymbolTable) -> TypeVar {
303        TypeVar::Known(loc, t_int(symtab), vec![self.new_generic_tluint(loc)])
304    }
305
306    /// Return a new generic int. The first returned value is int<N>, and the second
307    /// value is N
308    pub fn new_split_generic_int(
309        &mut self,
310        loc: Loc<()>,
311        symtab: &SymbolTable,
312    ) -> (TypeVar, TypeVar) {
313        let size = self.new_generic_tlint(loc);
314        let full = TypeVar::Known(loc, t_int(symtab), vec![size.clone()]);
315        (full, size)
316    }
317
318    pub fn new_split_generic_uint(
319        &mut self,
320        loc: Loc<()>,
321        symtab: &SymbolTable,
322    ) -> (TypeVar, TypeVar) {
323        let size = self.new_generic_tluint(loc);
324        let full = TypeVar::Known(loc, t_uint(symtab), vec![size.clone()]);
325        (full, size)
326    }
327
328    pub fn new_generic_with_meta(&mut self, loc: Loc<()>, meta: MetaType) -> TypeVar {
329        let id = self.new_typeid();
330        TypeVar::Unknown(loc, id, TraitList::empty(), meta)
331    }
332
333    pub fn new_generic_type(&mut self, loc: Loc<()>) -> TypeVar {
334        let id = self.new_typeid();
335        TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Type)
336    }
337
338    pub fn new_generic_any(&mut self) -> TypeVar {
339        let id = self.new_typeid();
340        // NOTE: ().nowhere() is fine here, we can never fail to unify with MetaType::Any so
341        // this loc will never be displayed
342        TypeVar::Unknown(().nowhere(), id, TraitList::empty(), MetaType::Any)
343    }
344
345    pub fn new_generic_tlbool(&mut self, loc: Loc<()>) -> TypeVar {
346        let id = self.new_typeid();
347        TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Bool)
348    }
349
350    pub fn new_generic_tluint(&mut self, loc: Loc<()>) -> TypeVar {
351        let id = self.new_typeid();
352        TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Uint)
353    }
354
355    pub fn new_generic_tlint(&mut self, loc: Loc<()>) -> TypeVar {
356        let id = self.new_typeid();
357        TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Int)
358    }
359
360    pub fn new_generic_tlnumber(&mut self, loc: Loc<()>) -> TypeVar {
361        let id = self.new_typeid();
362        TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Number)
363    }
364
365    pub fn new_generic_number(&mut self, loc: Loc<()>, ctx: &Context) -> (TypeVar, TypeVar) {
366        let number = ctx
367            .symtab
368            .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
369            .expect("Did not find number in symtab")
370            .0;
371        let id = self.new_typeid();
372        let size = self.new_generic_tluint(loc);
373        let t = TraitReq {
374            name: TraitName::Named(number.nowhere()),
375            type_params: vec![size.clone()],
376        }
377        .nowhere();
378        (
379            TypeVar::Unknown(loc, id, TraitList::from_vec(vec![t]), MetaType::Type),
380            size,
381        )
382    }
383
384    pub fn new_generic_with_traits(&mut self, loc: Loc<()>, traits: TraitList) -> TypeVar {
385        let id = self.new_typeid();
386        TypeVar::Unknown(loc, id, traits, MetaType::Type)
387    }
388
389    /// Returns the current pipeline state. `access_loc` is *not* used to specify where to get the
390    /// state from, it is only used for the Diagnostic::bug that gets emitted if there is no
391    /// pipeline state
392    pub fn get_pipeline_state<T>(&self, access_loc: &Loc<T>) -> Result<&PipelineState> {
393        self.pipeline_state
394            .as_ref()
395            .ok_or_else(|| diag_anyhow!(access_loc, "Expected to have a pipeline state"))
396    }
397
398    #[trace_typechecker]
399    #[tracing::instrument(level = "trace", skip_all, fields(%entity.name))]
400    pub fn visit_unit(&mut self, entity: &Loc<Unit>, ctx: &Context) -> Result<()> {
401        self.trait_impls = ctx.trait_impls.clone();
402
403        let generic_list = self.create_generic_list(
404            GenericListSource::Definition(&entity.name.name_id().inner),
405            &entity.head.unit_type_params,
406            &entity.head.scope_type_params,
407            None,
408            // NOTE: I'm not 100% sure we need to pass these here, the information
409            // is probably redundant
410            &entity.head.where_clauses,
411        )?;
412
413        // Add equations for the inputs
414        for (name, t) in &entity.inputs {
415            let tvar = self.type_var_from_hir(t.loc(), t, &generic_list)?;
416            self.add_equation(TypedExpression::Name(name.inner.clone()), tvar)
417        }
418
419        if let UnitKind::Pipeline {
420            depth,
421            depth_typeexpr_id,
422        } = &entity.head.unit_kind.inner
423        {
424            let depth_var = self.hir_type_expr_to_var(depth, &generic_list)?;
425            self.add_equation(TypedExpression::Id(*depth_typeexpr_id), depth_var.clone());
426            self.pipeline_state = Some(PipelineState {
427                current_stage_depth: TypeVar::Known(
428                    entity.head.unit_kind.loc(),
429                    KnownType::Integer(BigInt::zero()),
430                    vec![],
431                ),
432                pipeline_loc: entity.loc(),
433                total_depth: depth_var.clone().at_loc(depth),
434            });
435            self.add_requirement(Requirement::PositivePipelineDepth {
436                depth: depth_var.at_loc(depth),
437            });
438            self.unify(
439                &TypedExpression::Name(entity.inputs[0].0.clone().inner),
440                &t_clock(ctx.symtab).at_loc(&entity.head.unit_kind),
441                ctx,
442            )
443            .into_diagnostic(
444                entity.inputs[0].1.loc(),
445                |diag,
446                 Tm {
447                     g: got,
448                     e: _expected,
449                 }| {
450                    diag.message(format!(
451                        "First pipeline argument must be a clock. Got {got}"
452                    ))
453                    .primary_label("expected clock")
454                },
455            )?;
456            // In order to catch negative depth early when they are specified as literals,
457            // we'll instantly check requirements here
458            self.check_requirements(ctx)?;
459        }
460
461        self.visit_expression(&entity.body, ctx, &generic_list)?;
462
463        // Ensure that the output type matches what the user specified, and unit otherwise
464        if let Some(output_type) = &entity.head.output_type {
465            let tvar = self.type_var_from_hir(output_type.loc(), output_type, &generic_list)?;
466
467            self.trace_stack.push(TraceStackEntry::Message(format!(
468                "Unifying with output type {tvar:?}"
469            )));
470            self.unify(&TypedExpression::Id(entity.body.inner.id), &tvar, ctx)
471                .into_diagnostic_no_expected_source(
472                    &entity.body,
473                    |diag,
474                     Tm {
475                         g: got,
476                         e: expected,
477                     }| {
478                        // FIXME: Might want to check if there is no return value in the block
479                        // and in that case suggest adding it.
480                        diag.message(format!(
481                            "Output type mismatch. Expected {expected}, got {got}"
482                        ))
483                        .primary_label(format!("Found type {got}"))
484                        .secondary_label(output_type, format!("{expected} type specified here"))
485                    },
486                )?;
487        } else {
488            // No output type, so unify with the unit type.
489            self.unify(
490                &TypedExpression::Id(entity.body.inner.id),
491                &TypeVar::unit(entity.head.name.loc()),
492                ctx
493            )
494            .into_diagnostic_no_expected_source(entity.body.loc(), |diag, Tm{g: got, e: _expected}| {
495                diag.message("Output type mismatch")
496                    .primary_label(format!("Found type {got}"))
497                    .note(format!(
498                        "The {} does not specify a return type.\nAdd a return type, or remove the return value.",
499                        entity.head.unit_kind.name()
500                    ))
501            })?;
502        }
503
504        if let Some(PipelineState {
505            current_stage_depth,
506            pipeline_loc,
507            total_depth,
508        }) = self.pipeline_state.clone()
509        {
510            self.unify(&total_depth.inner, &current_stage_depth, ctx)
511                .into_diagnostic_no_expected_source(pipeline_loc, |diag, tm| {
512                    let (e, g) = tm.display_e_g();
513                    diag.message(format!("Pipeline depth mismatch. Expected {g} got {e}"))
514                        .primary_label(format!("Found {e} stages in this pipeline"))
515                })?;
516        }
517
518        self.check_requirements(ctx)?;
519
520        // NOTE: We may accidentally leak a stage depth if this function returns early. However,
521        // since we only use stage depths in pipelines where we re-set it when we enter,
522        // accidentally leaving a depth in place should not trigger any *new* compiler bugs, but
523        // may help track something down
524        self.pipeline_state = None;
525
526        Ok(())
527    }
528
529    #[trace_typechecker]
530    #[tracing::instrument(level = "trace", skip_all)]
531    fn visit_argument_list(
532        &mut self,
533        args: &Loc<ArgumentList<Expression>>,
534        ctx: &Context,
535        generic_list: &GenericListToken,
536    ) -> Result<()> {
537        for expr in args.expressions() {
538            self.visit_expression(expr, ctx, generic_list)?;
539        }
540        Ok(())
541    }
542
543    #[trace_typechecker]
544    fn type_check_argument_list(
545        &mut self,
546        args: &[Argument<Expression, TypeSpec>],
547        ctx: &Context,
548        generic_list: &GenericListToken,
549    ) -> Result<()> {
550        for Argument {
551            target,
552            target_type,
553            value,
554            kind,
555        } in args.iter()
556        {
557            let target_type = self.type_var_from_hir(value.loc(), target_type, generic_list)?;
558
559            let loc = match kind {
560                hir::param_util::ArgumentKind::Positional => value.loc(),
561                hir::param_util::ArgumentKind::Named => value.loc(),
562                hir::param_util::ArgumentKind::ShortNamed => target.loc(),
563            };
564
565            self.unify(&value.inner, &target_type, ctx)
566                .into_diagnostic(
567                    loc,
568                    |d,
569                     Tm {
570                         e: expected,
571                         g: got,
572                     }| {
573                        d.message(format!(
574                            "Argument type mismatch. Expected {expected} got {got}"
575                        ))
576                        .primary_label(format!("expected {expected}"))
577                    },
578                )?;
579        }
580
581        Ok(())
582    }
583
584    #[tracing::instrument(level = "trace", skip_all)]
585    #[trace_typechecker]
586    pub fn visit_expression(
587        &mut self,
588        expression: &Loc<Expression>,
589        ctx: &Context,
590        generic_list: &GenericListToken,
591    ) -> Result<()> {
592        let new_type = self.new_generic_type(expression.loc());
593        self.add_equation(TypedExpression::Id(expression.inner.id), new_type);
594
595        // Recurse down the expression
596        match &expression.inner.kind {
597            ExprKind::Identifier(_) => self.visit_identifier(expression, ctx)?,
598            ExprKind::TypeLevelInteger(_) => {
599                self.visit_type_level_integer(expression, generic_list, ctx)?
600            }
601            ExprKind::IntLiteral(_, _) => self.visit_int_literal(expression, ctx)?,
602            ExprKind::BoolLiteral(_) => self.visit_bool_literal(expression, ctx)?,
603            ExprKind::BitLiteral(_) => self.visit_bit_literal(expression, ctx)?,
604            ExprKind::TupleLiteral(_) => self.visit_tuple_literal(expression, ctx, generic_list)?,
605            ExprKind::TupleIndex(_, _) => self.visit_tuple_index(expression, ctx, generic_list)?,
606            ExprKind::ArrayLiteral(_) => self.visit_array_literal(expression, ctx, generic_list)?,
607            ExprKind::ArrayShorthandLiteral(_, _) => {
608                self.visit_array_shorthand_literal(expression, ctx, generic_list)?
609            }
610            ExprKind::CreatePorts => self.visit_create_ports(expression, ctx, generic_list)?,
611            ExprKind::FieldAccess(_, _) => {
612                self.visit_field_access(expression, ctx, generic_list)?
613            }
614            ExprKind::MethodCall { .. } => self.visit_method_call(expression, ctx, generic_list)?,
615            ExprKind::Index(_, _) => self.visit_index(expression, ctx, generic_list)?,
616            ExprKind::RangeIndex { .. } => self.visit_range_index(expression, ctx, generic_list)?,
617            ExprKind::Block(_) => self.visit_block_expr(expression, ctx, generic_list)?,
618            ExprKind::If(_, _, _) => self.visit_if(expression, ctx, generic_list)?,
619            ExprKind::Match(_, _) => self.visit_match(expression, ctx, generic_list)?,
620            ExprKind::BinaryOperator(_, _, _) => {
621                self.visit_binary_operator(expression, ctx, generic_list)?
622            }
623            ExprKind::UnaryOperator(_, _) => {
624                self.visit_unary_operator(expression, ctx, generic_list)?
625            }
626            ExprKind::Call {
627                kind,
628                callee,
629                args,
630                turbofish,
631            } => {
632                let head = ctx.symtab.unit_by_id(&callee.inner);
633
634                self.handle_function_like(
635                    expression.map_ref(|e| e.id),
636                    &expression.get_type(self)?,
637                    &FunctionLikeName::Free(callee.inner.clone()),
638                    &head,
639                    kind,
640                    args,
641                    ctx,
642                    true,
643                    false,
644                    turbofish.as_ref().map(|turbofish| TurbofishCtx {
645                        turbofish,
646                        prev_generic_list: generic_list,
647                        type_ctx: ctx,
648                    }),
649                    generic_list,
650                )?;
651            }
652            ExprKind::PipelineRef { .. } => {
653                self.visit_pipeline_ref(expression, generic_list, ctx)?;
654            }
655            ExprKind::StageReady | ExprKind::StageValid => {
656                self.unify_expression_generic_error(
657                    expression,
658                    &t_bool(ctx.symtab).at_loc(expression),
659                    ctx,
660                )?;
661            }
662            ExprKind::TypeLevelIf(cond, on_true, on_false) => {
663                let cond_var = self.visit_const_generic_with_id(
664                    cond,
665                    generic_list,
666                    ConstraintSource::TypeLevelIf,
667                    ctx,
668                )?;
669                let t_bool = self.new_generic_tlbool(cond.loc());
670                self.unify(&cond_var, &t_bool, ctx).into_diagnostic(
671                    cond,
672                    |diag, Tm { e: _, g }| {
673                        diag.message(format!("gen if conditions must be #bool, got {g}"))
674                    },
675                )?;
676
677                self.visit_expression(on_true, ctx, generic_list)?;
678                self.visit_expression(on_false, ctx, generic_list)?;
679
680                self.unify_expression_generic_error(expression, on_true.as_ref(), ctx)?;
681                self.unify_expression_generic_error(expression, on_false.as_ref(), ctx)?;
682            }
683            ExprKind::Null => {}
684        }
685        Ok(())
686    }
687
688    // Common handler for entities, functions and pipelines
689    #[tracing::instrument(level = "trace", skip_all, fields(%name))]
690    #[trace_typechecker]
691    fn handle_function_like(
692        &mut self,
693        expression_id: Loc<ExprID>,
694        expression_type: &TypeVar,
695        name: &FunctionLikeName,
696        head: &Loc<UnitHead>,
697        call_kind: &CallKind,
698        args: &Loc<ArgumentList<Expression>>,
699        ctx: &Context,
700        // Whether or not to visit the argument expressions passed to the function here. This
701        // should not be done if the expressoins have been visited before, for example, when
702        // handling methods
703        visit_args: bool,
704        // If we are calling a method, we have an implicit self argument which means
705        // that any error reporting number of arguments should be reduced by one
706        is_method: bool,
707        turbofish: Option<TurbofishCtx>,
708        generic_list: &GenericListToken,
709    ) -> Result<()> {
710        // Add new symbols for all the type parameters
711        let unit_generic_list = self.create_generic_list(
712            GenericListSource::Expression(expression_id.inner),
713            &head.unit_type_params,
714            &head.scope_type_params,
715            turbofish,
716            &head.where_clauses,
717        )?;
718
719        match (&head.unit_kind.inner, call_kind) {
720            (
721                UnitKind::Pipeline {
722                    depth: udepth,
723                    depth_typeexpr_id: _,
724                },
725                CallKind::Pipeline {
726                    inst_loc: _,
727                    depth: cdepth,
728                    depth_typeexpr_id: cdepth_typeexpr_id,
729                },
730            ) => {
731                let definition_depth = self.hir_type_expr_to_var(udepth, &unit_generic_list)?;
732                let call_depth = self.hir_type_expr_to_var(cdepth, generic_list)?;
733
734                // NOTE: We're not adding udepth_typeexpr_id here as that would break
735                // in the future if we try to do recursion. We will also never need to look
736                // up the depth in the definition
737                self.add_equation(TypedExpression::Id(*cdepth_typeexpr_id), call_depth.clone());
738
739                self.unify(&call_depth, &definition_depth, ctx)
740                    .into_diagnostic_no_expected_source(cdepth, |diag, Tm { e, g }| {
741                        diag.message("Pipeline depth mismatch")
742                            .primary_label(format!("Expected depth {e}, got {g}"))
743                            .secondary_label(udepth, format!("{name} has depth {e}"))
744                    })?;
745            }
746            _ => {}
747        }
748
749        if visit_args {
750            self.visit_argument_list(args, ctx, &generic_list)?;
751        }
752
753        let type_params = &head.get_type_params();
754
755        // Special handling of built in functions
756        macro_rules! handle_special_functions {
757            ($([$($path:expr),*] => $handler:expr),*) => {
758                $(
759                    let path = Path(vec![$(Identifier($path.to_string()).nowhere()),*]).nowhere();
760                    if ctx.symtab
761                        .try_lookup_final_id(&path, &[])
762                        .map(|n| &FunctionLikeName::Free(n) == name)
763                        .unwrap_or(false)
764                    {
765                        $handler
766                    };
767                )*
768            }
769        }
770
771        // NOTE: These would be better as a closure, but unfortunately that does not satisfy
772        // the borrow checker
773        macro_rules! generic_arg {
774            ($idx:expr) => {
775                self.get_generic_list(&unit_generic_list)[&type_params[$idx].name_id()].clone()
776            };
777        }
778
779        let matched_args =
780            match_args_with_params(args, &head.inputs.inner, is_method).map_err(|e| {
781                let diag: Diagnostic = e.into();
782                diag.secondary_label(
783                    head,
784                    format!("{kind} defined here", kind = head.unit_kind.name()),
785                )
786            })?;
787
788        handle_special_functions! {
789            ["std", "conv", "concat"] => {
790                self.handle_concat(
791                    expression_id,
792                    generic_arg!(0),
793                    generic_arg!(1),
794                    generic_arg!(2),
795                    &matched_args,
796                    ctx
797                )?
798            },
799            ["std", "conv", "trunc"] => {
800                self.handle_trunc(
801                    expression_id,
802                    generic_arg!(0),
803                    generic_arg!(1),
804                    &matched_args,
805                    ctx
806                )?
807            },
808            ["std", "ops", "comb_div"] => {
809                self.handle_comb_mod_or_div(
810                    generic_arg!(0),
811                    &matched_args,
812                    ctx
813                )?
814            },
815            ["std", "ops", "comb_mod"] => {
816                self.handle_comb_mod_or_div(
817                    generic_arg!(0),
818                    &matched_args,
819                    ctx
820                )?
821            },
822            ["std", "mem", "clocked_memory"]  => {
823                let num_elements = generic_arg!(0);
824                let addr_size = generic_arg!(2);
825
826                self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
827            },
828            // NOTE: the last argument of _init does not need special treatment, so
829            // we can use the same code path for both for now
830            ["std", "mem", "clocked_memory_init"]  => {
831                let num_elements = generic_arg!(0);
832                let addr_size = generic_arg!(2);
833
834                self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
835            },
836            ["std", "mem", "read_memory"]  => {
837                let addr_size = generic_arg!(0);
838                let num_elements = generic_arg!(2);
839
840                self.handle_read_memory(num_elements, addr_size, &matched_args, ctx)?
841            }
842        };
843
844        // Unify the types of the arguments
845        self.type_check_argument_list(&matched_args, ctx, &unit_generic_list)?;
846
847        let return_type = head
848            .output_type
849            .as_ref()
850            .map(|o| self.type_var_from_hir(expression_id.loc(), o, &unit_generic_list))
851            .transpose()?
852            .unwrap_or_else(|| TypeVar::Known(expression_id.loc(), KnownType::Tuple, vec![]));
853
854        self.unify(expression_type, &return_type, ctx)
855            .into_default_diagnostic(expression_id.loc())?;
856
857        Ok(())
858    }
859
860    pub fn handle_concat(
861        &mut self,
862        expression_id: Loc<ExprID>,
863        source_lhs_ty: TypeVar,
864        source_rhs_ty: TypeVar,
865        source_result_ty: TypeVar,
866        args: &[Argument<Expression, TypeSpec>],
867        ctx: &Context,
868    ) -> Result<()> {
869        let (lhs_type, lhs_size) = self.new_generic_number(expression_id.loc(), ctx);
870        let (rhs_type, rhs_size) = self.new_generic_number(expression_id.loc(), ctx);
871        let (result_type, result_size) = self.new_generic_number(expression_id.loc(), ctx);
872        self.unify(&source_lhs_ty, &lhs_type, ctx)
873            .into_default_diagnostic(args[0].value.loc())?;
874        self.unify(&source_rhs_ty, &rhs_type, ctx)
875            .into_default_diagnostic(args[1].value.loc())?;
876        self.unify(&source_result_ty, &result_type, ctx)
877            .into_default_diagnostic(expression_id.loc())?;
878
879        // Result size is sum of input sizes
880        self.add_constraint(
881            result_size.clone(),
882            ce_var(&lhs_size) + ce_var(&rhs_size),
883            expression_id.loc(),
884            &result_size,
885            ConstraintSource::Concatenation,
886        );
887        self.add_constraint(
888            lhs_size.clone(),
889            ce_var(&result_size) + -ce_var(&rhs_size),
890            args[0].value.loc(),
891            &lhs_size,
892            ConstraintSource::Concatenation,
893        );
894        self.add_constraint(
895            rhs_size.clone(),
896            ce_var(&result_size) + -ce_var(&lhs_size),
897            args[1].value.loc(),
898            &rhs_size,
899            ConstraintSource::Concatenation,
900        );
901
902        self.add_requirement(Requirement::SharedBase(vec![
903            lhs_type.at_loc(args[0].value),
904            rhs_type.at_loc(args[1].value),
905            result_type.at_loc(&expression_id.loc()),
906        ]));
907        Ok(())
908    }
909
910    pub fn handle_trunc(
911        &mut self,
912        expression_id: Loc<ExprID>,
913        source_in_ty: TypeVar,
914        source_result_ty: TypeVar,
915        args: &[Argument<Expression, TypeSpec>],
916        ctx: &Context,
917    ) -> Result<()> {
918        let (in_ty, _) = self.new_generic_number(expression_id.loc(), ctx);
919        let (result_type, _) = self.new_generic_number(expression_id.loc(), ctx);
920        self.unify(&source_in_ty, &in_ty, ctx)
921            .into_default_diagnostic(args[0].value.loc())?;
922        self.unify(&source_result_ty, &result_type, ctx)
923            .into_default_diagnostic(expression_id.loc())?;
924
925        self.add_requirement(Requirement::SharedBase(vec![
926            in_ty.at_loc(args[0].value),
927            result_type.at_loc(&expression_id.loc()),
928        ]));
929        Ok(())
930    }
931
932    pub fn handle_comb_mod_or_div(
933        &mut self,
934        n_ty: TypeVar,
935        args: &[Argument<Expression, TypeSpec>],
936        ctx: &Context,
937    ) -> Result<()> {
938        let (num, _) = self.new_generic_number(args[0].value.loc(), ctx);
939        self.unify(&n_ty, &num, ctx)
940            .into_default_diagnostic(args[0].value.loc())?;
941        Ok(())
942    }
943
944    pub fn handle_clocked_memory(
945        &mut self,
946        num_elements: TypeVar,
947        addr_size_arg: TypeVar,
948        args: &[Argument<Expression, TypeSpec>],
949        ctx: &Context,
950    ) -> Result<()> {
951        // FIXME: When we support where clauses, we should move this
952        // definition into the definition of clocked_memory
953        let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
954        let arg1_loc = args[1].value.loc();
955        let port_type = TypeVar::array(
956            arg1_loc,
957            TypeVar::tuple(
958                args[1].value.loc(),
959                vec![
960                    self.new_generic_type(arg1_loc),
961                    addr_type,
962                    self.new_generic_type(arg1_loc),
963                ],
964            ),
965            self.new_generic_tluint(arg1_loc),
966        );
967
968        self.add_constraint(
969            addr_size.clone(),
970            bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
971            args[1].value.loc(),
972            &port_type,
973            ConstraintSource::MemoryIndexing,
974        );
975
976        // NOTE: Unwrap is safe, size is still generic at this point
977        self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
978        self.unify_expression_generic_error(args[1].value, &port_type, ctx)?;
979
980        Ok(())
981    }
982
983    pub fn handle_read_memory(
984        &mut self,
985        num_elements: TypeVar,
986        addr_size_arg: TypeVar,
987        args: &[Argument<Expression, TypeSpec>],
988        ctx: &Context,
989    ) -> Result<()> {
990        let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
991
992        self.add_constraint(
993            addr_size.clone(),
994            bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
995            args[1].value.loc(),
996            &addr_type,
997            ConstraintSource::MemoryIndexing,
998        );
999
1000        // NOTE: Unwrap is safe, size is still generic at this point
1001        self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1002
1003        Ok(())
1004    }
1005
1006    #[tracing::instrument(level = "trace", skip(self, turbofish, where_clauses))]
1007    pub fn create_generic_list(
1008        &mut self,
1009        source: GenericListSource,
1010        type_params: &[Loc<TypeParam>],
1011        scope_type_params: &[Loc<TypeParam>],
1012        turbofish: Option<TurbofishCtx>,
1013        where_clauses: &[Loc<WhereClause>],
1014    ) -> Result<GenericListToken> {
1015        let turbofish_params = if let Some(turbofish) = turbofish.as_ref() {
1016            if type_params.is_empty() {
1017                return Err(Diagnostic::error(
1018                    turbofish.turbofish,
1019                    "Turbofish on non-generic function",
1020                )
1021                .primary_label("Turbofish on non-generic function"));
1022            }
1023
1024            let matched_params =
1025                param_util::match_args_with_params(turbofish.turbofish, &type_params, false)?;
1026
1027            // We want this to be positional, but the params we get from matching are
1028            // named, transform it. We'll do some unwrapping here, but it is safe
1029            // because we know all params are present
1030            matched_params
1031                .iter()
1032                .map(|matched_param| {
1033                    let i = type_params
1034                        .iter()
1035                        .enumerate()
1036                        .find_map(|(i, param)| match &param.inner {
1037                            TypeParam {
1038                                ident,
1039                                name_id: _,
1040                                trait_bounds: _,
1041                                meta: _,
1042                            } => {
1043                                if ident == matched_param.target {
1044                                    Some(i)
1045                                } else {
1046                                    None
1047                                }
1048                            }
1049                        })
1050                        .unwrap();
1051                    (i, matched_param)
1052                })
1053                .sorted_by_key(|(i, _)| *i)
1054                .map(|(_, mp)| Some(mp.value))
1055                .collect::<Vec<_>>()
1056        } else {
1057            type_params.iter().map(|_| None).collect::<Vec<_>>()
1058        };
1059
1060        let mut inline_trait_bounds: Vec<Loc<WhereClause>> = vec![];
1061
1062        let scope_type_params = scope_type_params
1063            .iter()
1064            .map(|param| {
1065                let hir::TypeParam {
1066                    ident,
1067                    name_id,
1068                    trait_bounds,
1069                    meta,
1070                } = &param.inner;
1071                if !trait_bounds.is_empty() {
1072                    if let MetaType::Type = meta {
1073                        inline_trait_bounds.push(
1074                            WhereClause::Type {
1075                                target: name_id.clone().at_loc(ident),
1076                                traits: trait_bounds.clone(),
1077                            }
1078                            .at_loc(param),
1079                        );
1080                    } else {
1081                        return Err(Diagnostic::bug(param, "Trait bounds on generic int")
1082                            .primary_label("Trait bounds are only allowed on type parameters"));
1083                    }
1084                }
1085                Ok((
1086                    name_id.clone(),
1087                    self.new_generic_with_meta(param.loc(), meta.clone()),
1088                ))
1089            })
1090            .collect::<Result<Vec<_>>>()?;
1091
1092        let new_list = type_params
1093            .iter()
1094            .enumerate()
1095            .map(|(i, param)| {
1096                let hir::TypeParam {
1097                    ident,
1098                    name_id,
1099                    trait_bounds,
1100                    meta,
1101                } = &param.inner;
1102
1103                let t = self.new_generic_with_meta(param.loc(), meta.clone());
1104
1105                if let Some(tf) = &turbofish_params[i] {
1106                    let tf_ctx = turbofish.as_ref().unwrap();
1107                    let ty = self.hir_type_expr_to_var(tf, tf_ctx.prev_generic_list)?;
1108                    self.unify(&ty, &t, tf_ctx.type_ctx)
1109                        .into_default_diagnostic(param)?;
1110                }
1111
1112                if !trait_bounds.is_empty() {
1113                    if let MetaType::Type = meta {
1114                        inline_trait_bounds.push(
1115                            WhereClause::Type {
1116                                target: name_id.clone().at_loc(ident),
1117                                traits: trait_bounds.clone(),
1118                            }
1119                            .at_loc(param),
1120                        );
1121                    }
1122                    Ok((name_id.clone(), t))
1123                } else {
1124                    Ok((name_id.clone(), self.check_var_for_replacement(t)))
1125                }
1126            })
1127            .collect::<Result<Vec<_>>>()?
1128            .into_iter()
1129            .chain(scope_type_params.into_iter())
1130            .map(|(name, t)| (name, t.clone()))
1131            .collect::<HashMap<_, _>>();
1132
1133        self.trace_stack
1134            .push(TraceStackEntry::NewGenericList(new_list.clone()));
1135
1136        let token = self.add_mapped_generic_list(source, new_list.clone());
1137
1138        for constraint in where_clauses.iter().chain(inline_trait_bounds.iter()) {
1139            match &constraint.inner {
1140                WhereClause::Type { target, traits } => {
1141                    self.visit_trait_bounds(target, traits.as_slice(), &token)?;
1142                }
1143                WhereClause::Int { target, constraint } => {
1144                    let int_constraint = self.visit_const_generic(constraint, &token)?;
1145                    let tvar = new_list.get(target).ok_or_else(|| {
1146                        Diagnostic::error(
1147                            target,
1148                            format!("{target} is not a generic parameter on this unit"),
1149                        )
1150                        .primary_label("Not a generic parameter")
1151                    })?;
1152
1153                    self.add_constraint(
1154                        tvar.clone(),
1155                        int_constraint,
1156                        constraint.loc(),
1157                        &tvar,
1158                        ConstraintSource::Where,
1159                    );
1160                }
1161            }
1162        }
1163
1164        Ok(token)
1165    }
1166
1167    /// Adds a generic list with parameters already mapped to types
1168    pub fn add_mapped_generic_list(
1169        &mut self,
1170        source: GenericListSource,
1171        mapping: HashMap<NameID, TypeVar>,
1172    ) -> GenericListToken {
1173        let reference = match source {
1174            GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1175            GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1176            GenericListSource::ImplBlock { target, id } => {
1177                GenericListToken::ImplBlock(target.clone(), id)
1178            }
1179            GenericListSource::Expression(id) => GenericListToken::Expression(id),
1180        };
1181
1182        if self
1183            .generic_lists
1184            .insert(reference.clone(), mapping)
1185            .is_some()
1186        {
1187            panic!("A generic list already existed for {reference:?}");
1188        }
1189        reference
1190    }
1191
1192    #[tracing::instrument(level = "trace", skip_all)]
1193    #[trace_typechecker]
1194    pub fn visit_block(
1195        &mut self,
1196        block: &Block,
1197        ctx: &Context,
1198        generic_list: &GenericListToken,
1199    ) -> Result<()> {
1200        for statement in &block.statements {
1201            self.visit_statement(statement, ctx, generic_list)?;
1202        }
1203        if let Some(result) = &block.result {
1204            self.visit_expression(result, ctx, generic_list)?;
1205        }
1206        Ok(())
1207    }
1208
1209    #[tracing::instrument(level = "trace", skip_all)]
1210    #[trace_typechecker]
1211    pub fn visit_impl_blocks(&mut self, item_list: &ItemList) -> Result<TraitImplList> {
1212        let mut trait_impls = TraitImplList::new();
1213        for (target, impls) in &item_list.impls {
1214            for ((trait_name, type_expressions), impl_block) in impls {
1215                let generic_list = self.create_generic_list(
1216                    GenericListSource::ImplBlock {
1217                        target,
1218                        id: impl_block.id,
1219                    },
1220                    &[],
1221                    impl_block.type_params.as_slice(),
1222                    None,
1223                    &[],
1224                )?;
1225
1226                let loc = trait_name
1227                    .name_loc()
1228                    .map(|n| ().at_loc(&n))
1229                    .unwrap_or(().at_loc(&impl_block));
1230
1231                let mapped_type_vars = type_expressions
1232                    .iter()
1233                    .map(|param| {
1234                        self.hir_type_expr_to_var(&param.clone().at_loc(&loc), &generic_list)
1235                    })
1236                    .collect::<Result<_>>()?;
1237
1238                trait_impls
1239                    .inner
1240                    .entry(target.clone())
1241                    .or_default()
1242                    .push(TraitImpl {
1243                        name: trait_name.clone(),
1244                        type_params: mapped_type_vars,
1245                        impl_block: impl_block.inner.clone(),
1246                    });
1247            }
1248        }
1249
1250        Ok(trait_impls)
1251    }
1252
1253    #[trace_typechecker]
1254    pub fn visit_pattern(
1255        &mut self,
1256        pattern: &Loc<Pattern>,
1257        ctx: &Context,
1258        generic_list: &GenericListToken,
1259    ) -> Result<()> {
1260        let new_type = self.new_generic_type(pattern.loc());
1261        self.add_equation(TypedExpression::Id(pattern.inner.id), new_type);
1262        match &pattern.inner.kind {
1263            hir::PatternKind::Integer(val) => {
1264                let (num_t, _) = &self.new_generic_number(pattern.loc(), ctx);
1265                self.add_requirement(Requirement::FitsIntLiteral {
1266                    value: ConstantInt::Literal(val.clone()),
1267                    target_type: num_t.clone().at_loc(pattern),
1268                });
1269                self.unify(pattern, num_t, ctx)
1270                    .expect("Failed to unify new_generic with int");
1271            }
1272            hir::PatternKind::Bool(_) => {
1273                self.unify(pattern, &t_bool(ctx.symtab).at_loc(pattern), ctx)
1274                    .expect("Expected new_generic with boolean");
1275            }
1276            hir::PatternKind::Name { name, pre_declared } => {
1277                if !pre_declared {
1278                    self.add_equation(
1279                        TypedExpression::Name(name.clone().inner),
1280                        pattern.get_type(self)?,
1281                    );
1282                }
1283                self.unify(
1284                    &TypedExpression::Id(pattern.id),
1285                    &TypedExpression::Name(name.clone().inner),
1286                    ctx,
1287                )
1288                .into_default_diagnostic(name.loc())?;
1289            }
1290            hir::PatternKind::Tuple(subpatterns) => {
1291                for pattern in subpatterns {
1292                    self.visit_pattern(pattern, ctx, generic_list)?;
1293                }
1294                let tuple_type = TypeVar::tuple(
1295                    pattern.loc(),
1296                    subpatterns
1297                        .iter()
1298                        .map(|pattern| {
1299                            let p_type = pattern.get_type(self)?;
1300                            Ok(p_type)
1301                        })
1302                        .collect::<Result<_>>()?,
1303                );
1304
1305                self.unify(pattern, &tuple_type, ctx)
1306                    .expect("Unification of new_generic with tuple type cannot fail");
1307            }
1308            hir::PatternKind::Array(inner) => {
1309                for pattern in inner {
1310                    self.visit_pattern(pattern, ctx, generic_list)?;
1311                }
1312                if inner.len() == 0 {
1313                    return Err(
1314                        Diagnostic::error(pattern, "Empty array patterns are unsupported")
1315                            .primary_label("Empty array pattern"),
1316                    );
1317                } else {
1318                    let inner_t = inner[0].get_type(self)?;
1319
1320                    for pattern in inner.iter().skip(1) {
1321                        self.unify(pattern, &inner_t, ctx)
1322                            .into_default_diagnostic(pattern)?;
1323                    }
1324
1325                    self.unify(
1326                        pattern,
1327                        &TypeVar::Known(
1328                            pattern.loc(),
1329                            KnownType::Array,
1330                            vec![
1331                                inner_t,
1332                                TypeVar::Known(
1333                                    pattern.loc(),
1334                                    KnownType::Integer(inner.len().to_bigint()),
1335                                    vec![],
1336                                ),
1337                            ],
1338                        ),
1339                        ctx,
1340                    )
1341                    .into_default_diagnostic(pattern)?;
1342                }
1343            }
1344            hir::PatternKind::Type(name, args) => {
1345                let (condition_type, params, generic_list) =
1346                    match ctx.symtab.patternable_type_by_id(name).inner {
1347                        Patternable {
1348                            kind: PatternableKind::Enum,
1349                            params: _,
1350                        } => {
1351                            let enum_variant = ctx.symtab.enum_variant_by_id(name).inner;
1352                            let generic_list = self.create_generic_list(
1353                                GenericListSource::Anonymous,
1354                                &enum_variant.type_params,
1355                                &[],
1356                                None,
1357                                &[],
1358                            )?;
1359
1360                            let condition_type = self.type_var_from_hir(
1361                                pattern.loc(),
1362                                &enum_variant.output_type,
1363                                &generic_list,
1364                            )?;
1365
1366                            (condition_type, enum_variant.params, generic_list)
1367                        }
1368                        Patternable {
1369                            kind: PatternableKind::Struct,
1370                            params: _,
1371                        } => {
1372                            let s = ctx.symtab.struct_by_id(name).inner;
1373                            let generic_list = self.create_generic_list(
1374                                GenericListSource::Anonymous,
1375                                &s.type_params,
1376                                &[],
1377                                None,
1378                                &[],
1379                            )?;
1380
1381                            let condition_type =
1382                                self.type_var_from_hir(pattern.loc(), &s.self_type, &generic_list)?;
1383
1384                            (condition_type, s.params, generic_list)
1385                        }
1386                    };
1387
1388                self.unify(pattern, &condition_type, ctx)
1389                    .expect("Unification of new_generic with enum cannot fail");
1390
1391                for (
1392                    PatternArgument {
1393                        target,
1394                        value: pattern,
1395                        kind,
1396                    },
1397                    Parameter {
1398                        name: _,
1399                        ty: target_type,
1400                        no_mangle: _,
1401                    },
1402                ) in args.iter().zip(params.0.iter())
1403                {
1404                    self.visit_pattern(pattern, ctx, &generic_list)?;
1405                    let target_type =
1406                        self.type_var_from_hir(target_type.loc(), target_type, &generic_list)?;
1407
1408                    let loc = match kind {
1409                        hir::ArgumentKind::Positional => pattern.loc(),
1410                        hir::ArgumentKind::Named => pattern.loc(),
1411                        hir::ArgumentKind::ShortNamed => target.loc(),
1412                    };
1413
1414                    self.unify(pattern, &target_type, ctx).into_diagnostic(
1415                        loc,
1416                        |d,
1417                         Tm {
1418                             e: expected,
1419                             g: got,
1420                         }| {
1421                            d.message(format!(
1422                                "Argument type mismatch. Expected {expected} got {got}"
1423                            ))
1424                            .primary_label(format!("expected {expected}"))
1425                        },
1426                    )?;
1427                }
1428            }
1429        }
1430        Ok(())
1431    }
1432
1433    #[trace_typechecker]
1434    pub fn visit_wal_trace(
1435        &mut self,
1436        trace: &Loc<WalTrace>,
1437        ctx: &Context,
1438        generic_list: &GenericListToken,
1439    ) -> Result<()> {
1440        let WalTrace { clk, rst } = &trace.inner;
1441        clk.as_ref()
1442            .map(|x| {
1443                self.visit_expression(x, ctx, generic_list)?;
1444                self.unify_expression_generic_error(x, &t_clock(ctx.symtab).at_loc(trace), ctx)
1445            })
1446            .transpose()?;
1447        rst.as_ref()
1448            .map(|x| {
1449                self.visit_expression(x, ctx, generic_list)?;
1450                self.unify_expression_generic_error(x, &t_bool(ctx.symtab).at_loc(trace), ctx)
1451            })
1452            .transpose()?;
1453        Ok(())
1454    }
1455
1456    #[trace_typechecker]
1457    pub fn visit_statement(
1458        &mut self,
1459        stmt: &Loc<Statement>,
1460        ctx: &Context,
1461        generic_list: &GenericListToken,
1462    ) -> Result<()> {
1463        match &stmt.inner {
1464            Statement::Binding(Binding {
1465                pattern,
1466                ty,
1467                value,
1468                wal_trace,
1469            }) => {
1470                trace!("Visiting `let {} = ..`", pattern.kind);
1471                self.visit_expression(value, ctx, generic_list)?;
1472
1473                self.visit_pattern(pattern, ctx, generic_list)?;
1474
1475                self.unify(&TypedExpression::Id(pattern.id), value, ctx)
1476                    .into_diagnostic(
1477                        pattern.loc(),
1478                        error_pattern_type_mismatch(
1479                            ty.as_ref().map(|t| t.loc()).unwrap_or_else(|| value.loc()),
1480                        ),
1481                    )?;
1482
1483                if let Some(t) = ty {
1484                    let tvar = self.type_var_from_hir(t.loc(), t, generic_list)?;
1485                    self.unify(&TypedExpression::Id(pattern.id), &tvar, ctx)
1486                        .into_default_diagnostic(value.loc())?;
1487                }
1488
1489                wal_trace
1490                    .as_ref()
1491                    .map(|wt| self.visit_wal_trace(wt, ctx, generic_list))
1492                    .transpose()?;
1493
1494                Ok(())
1495            }
1496            Statement::Expression(expr) => self.visit_expression(expr, ctx, generic_list),
1497            Statement::Register(reg) => self.visit_register(reg, ctx, generic_list),
1498            Statement::Declaration(names) => {
1499                for name in names {
1500                    let new_type = self.new_generic_type(name.loc());
1501                    self.add_equation(TypedExpression::Name(name.clone().inner), new_type);
1502                }
1503                Ok(())
1504            }
1505            Statement::PipelineRegMarker(extra) => {
1506                match extra {
1507                    Some(PipelineRegMarkerExtra::Condition(cond)) => {
1508                        self.visit_expression(cond, ctx, generic_list)?;
1509                        self.unify_expression_generic_error(
1510                            cond,
1511                            &t_bool(ctx.symtab).at_loc(cond),
1512                            ctx,
1513                        )?;
1514                    }
1515                    Some(PipelineRegMarkerExtra::Count {
1516                        count: _,
1517                        count_typeexpr_id: _,
1518                    }) => {}
1519                    None => {}
1520                }
1521
1522                let current_stage_depth = self
1523                    .pipeline_state
1524                    .clone()
1525                    .ok_or_else(|| {
1526                        diag_anyhow!(stmt, "Found a pipeline reg marker in a non-pipeline")
1527                    })?
1528                    .current_stage_depth;
1529
1530                let new_depth = self.new_generic_tlint(stmt.loc());
1531                let offset = match extra {
1532                    Some(PipelineRegMarkerExtra::Count {
1533                        count,
1534                        count_typeexpr_id,
1535                    }) => {
1536                        let var = self.hir_type_expr_to_var(count, generic_list)?;
1537                        self.add_equation(TypedExpression::Id(*count_typeexpr_id), var.clone());
1538                        var
1539                    }
1540                    Some(PipelineRegMarkerExtra::Condition(_)) | None => {
1541                        TypeVar::Known(stmt.loc(), KnownType::Integer(1.to_bigint()), vec![])
1542                    }
1543                };
1544
1545                let total_depth = ConstraintExpr::Sum(
1546                    Box::new(ConstraintExpr::Var(offset)),
1547                    Box::new(ConstraintExpr::Var(current_stage_depth)),
1548                );
1549                self.pipeline_state
1550                    .as_mut()
1551                    .expect("Expected to have a pipeline state")
1552                    .current_stage_depth = new_depth.clone();
1553
1554                self.add_constraint(
1555                    new_depth.clone(),
1556                    total_depth,
1557                    stmt.loc(),
1558                    &new_depth,
1559                    ConstraintSource::PipelineRegCount {
1560                        reg: stmt.loc(),
1561                        total: self.get_pipeline_state(stmt)?.total_depth.loc(),
1562                    },
1563                );
1564
1565                Ok(())
1566            }
1567            Statement::Label(name) => {
1568                let key = TypedExpression::Name(name.inner.clone());
1569                let var = if !self.equations.contains_key(&key) {
1570                    let var = self.new_generic_tlint(name.loc());
1571                    self.trace_stack.push(TraceStackEntry::AddingPipelineLabel(
1572                        name.inner.clone(),
1573                        var.clone(),
1574                    ));
1575                    self.add_equation(key.clone(), var.clone());
1576                    var
1577                } else {
1578                    let var = self.equations.get(&key).unwrap().clone();
1579                    self.trace_stack
1580                        .push(TraceStackEntry::RecoveringPipelineLabel(
1581                            name.inner.clone(),
1582                            var.clone(),
1583                        ));
1584                    var
1585                };
1586                // Safe unwrap, unifying with a fresh var
1587                self.unify(
1588                    &var,
1589                    &self.get_pipeline_state(name)?.current_stage_depth.clone(),
1590                    ctx,
1591                )
1592                .unwrap();
1593                Ok(())
1594            }
1595            Statement::WalSuffixed { .. } => Ok(()),
1596            Statement::Assert(expr) => {
1597                self.visit_expression(expr, ctx, generic_list)?;
1598                self.unify_expression_generic_error(expr, &t_bool(ctx.symtab).at_loc(stmt), ctx)?;
1599                Ok(())
1600            }
1601            Statement::Set { target, value } => {
1602                self.visit_expression(target, ctx, generic_list)?;
1603                self.visit_expression(value, ctx, generic_list)?;
1604
1605                let inner_type = self.new_generic_type(value.loc());
1606                let outer_type = TypeVar::backward(stmt.loc(), inner_type.clone());
1607                self.unify_expression_generic_error(target, &outer_type, ctx)?;
1608                self.unify_expression_generic_error(value, &inner_type, ctx)?;
1609
1610                Ok(())
1611            }
1612        }
1613    }
1614
1615    #[trace_typechecker]
1616    pub fn visit_register(
1617        &mut self,
1618        reg: &Register,
1619        ctx: &Context,
1620        generic_list: &GenericListToken,
1621    ) -> Result<()> {
1622        self.visit_pattern(&reg.pattern, ctx, generic_list)?;
1623
1624        let type_spec_type = match &reg.value_type {
1625            Some(t) => Some(self.type_var_from_hir(t.loc(), t, generic_list)?.at_loc(t)),
1626            None => None,
1627        };
1628
1629        // We need to do this before visiting value, in case it constrains the
1630        // type of the identifiers in the pattern
1631        if let Some(tvar) = &type_spec_type {
1632            self.unify(&TypedExpression::Id(reg.pattern.id), tvar, ctx)
1633                .into_diagnostic_no_expected_source(
1634                    reg.pattern.loc(),
1635                    error_pattern_type_mismatch(tvar.loc()),
1636                )?;
1637        }
1638
1639        self.visit_expression(&reg.clock, ctx, generic_list)?;
1640        self.visit_expression(&reg.value, ctx, generic_list)?;
1641
1642        if let Some(tvar) = &type_spec_type {
1643            self.unify(&reg.value, tvar, ctx)
1644                .into_default_diagnostic(reg.value.loc())?;
1645        }
1646
1647        if let Some((rst_cond, rst_value)) = &reg.reset {
1648            self.visit_expression(rst_cond, ctx, generic_list)?;
1649            self.visit_expression(rst_value, ctx, generic_list)?;
1650            // Ensure cond is a boolean
1651            self.unify(&rst_cond.inner, &t_bool(ctx.symtab).at_loc(rst_cond), ctx)
1652                .into_diagnostic(
1653                    rst_cond.loc(),
1654                    |diag,
1655                     Tm {
1656                         g: got,
1657                         e: _expected,
1658                     }| {
1659                        diag.message(format!(
1660                            "Register reset condition must be a bool, got {got}"
1661                        ))
1662                        .primary_label("expected bool")
1663                    },
1664                )?;
1665
1666            // Ensure the reset value has the same type as the register itself
1667            self.unify(&rst_value.inner, &reg.value.inner, ctx)
1668                .into_diagnostic(
1669                    rst_value.loc(),
1670                    |diag,
1671                     Tm {
1672                         g: got,
1673                         e: expected,
1674                     }| {
1675                        diag.message(format!(
1676                            "Register reset value mismatch. Expected {expected} got {got}"
1677                        ))
1678                        .primary_label(format!("expected {expected}"))
1679                        .secondary_label(&reg.pattern, format!("because this has type {expected}"))
1680                    },
1681                )?;
1682        }
1683
1684        if let Some(initial) = &reg.initial {
1685            self.visit_expression(initial, ctx, generic_list)?;
1686
1687            self.unify(&initial.inner, &reg.value.inner, ctx)
1688                .into_diagnostic(
1689                    initial.loc(),
1690                    |diag,
1691                     Tm {
1692                         g: got,
1693                         e: expected,
1694                     }| {
1695                        diag.message(format!(
1696                            "Register initial value mismatch. Expected {expected} got {got}"
1697                        ))
1698                        .primary_label(format!("expected {expected}, got {got}"))
1699                        .secondary_label(&reg.pattern, format!("because this has type {got}"))
1700                    },
1701                )?;
1702        }
1703
1704        self.unify(&reg.clock, &t_clock(ctx.symtab).at_loc(&reg.clock), ctx)
1705            .into_diagnostic(
1706                reg.clock.loc(),
1707                |diag,
1708                 Tm {
1709                     g: got,
1710                     e: _expected,
1711                 }| {
1712                    diag.message(format!("Expected clock, got {got}"))
1713                        .primary_label("expected clock")
1714                },
1715            )?;
1716
1717        self.unify(&TypedExpression::Id(reg.pattern.id), &reg.value, ctx)
1718            .into_diagnostic(
1719                reg.pattern.loc(),
1720                error_pattern_type_mismatch(reg.value.loc()),
1721            )?;
1722
1723        Ok(())
1724    }
1725
1726    #[trace_typechecker]
1727    pub fn visit_trait_spec(
1728        &mut self,
1729        trait_spec: &Loc<TraitSpec>,
1730        generic_list: &GenericListToken,
1731    ) -> Result<Loc<TraitReq>> {
1732        let type_params = if let Some(type_params) = &trait_spec.inner.type_params {
1733            type_params
1734                .inner
1735                .iter()
1736                .map(|te| self.hir_type_expr_to_var(te, generic_list))
1737                .collect::<Result<_>>()?
1738        } else {
1739            vec![]
1740        };
1741
1742        Ok(TraitReq {
1743            name: trait_spec.name.clone(),
1744            type_params,
1745        }
1746        .at_loc(trait_spec))
1747    }
1748
1749    #[trace_typechecker]
1750    pub fn visit_trait_bounds(
1751        &mut self,
1752        target: &Loc<NameID>,
1753        traits: &[Loc<TraitSpec>],
1754        generic_list: &GenericListToken,
1755    ) -> Result<()> {
1756        let trait_reqs = traits
1757            .iter()
1758            .map(|spec| self.visit_trait_spec(spec, generic_list))
1759            .collect::<Result<BTreeSet<_>>>()?
1760            .into_iter()
1761            .collect_vec();
1762
1763        if !trait_reqs.is_empty() {
1764            let trait_list = TraitList::from_vec(trait_reqs);
1765
1766            let generic_list = self.generic_lists.get_mut(generic_list).unwrap();
1767
1768            let Some(tvar) = generic_list.get(&target.inner) else {
1769                return Err(Diagnostic::bug(
1770                    target,
1771                    "Couldn't find generic from where clause in generic list",
1772                )
1773                .primary_label(format!(
1774                    "Generic type {} not found in generic list",
1775                    target.inner
1776                )));
1777            };
1778
1779            self.trace_stack.push(TraceStackEntry::AddingTraitBounds(
1780                tvar.clone(),
1781                trait_list.clone(),
1782            ));
1783
1784            let TypeVar::Unknown(loc, id, old_trait_list, MetaType::Type) = tvar else {
1785                return Err(Diagnostic::bug(
1786                    target,
1787                    "Trait bounds on known type or type-level integer",
1788                )
1789                .primary_label(format!(
1790                    "Trait bounds on {}, which should've been caught in ast-lowering",
1791                    target.inner
1792                )));
1793            };
1794
1795            let new_tvar = TypeVar::Unknown(
1796                *loc,
1797                *id,
1798                old_trait_list.clone().extend(trait_list),
1799                MetaType::Type,
1800            );
1801
1802            trace!(
1803                "Adding trait bound {} on type {}",
1804                new_tvar.display_with_meta(true),
1805                target.inner
1806            );
1807
1808            generic_list.insert(target.inner.clone(), new_tvar);
1809        }
1810
1811        Ok(())
1812    }
1813
1814    pub fn visit_const_generic_with_id(
1815        &mut self,
1816        gen: &Loc<ConstGenericWithId>,
1817        generic_list_token: &GenericListToken,
1818        constraint_source: ConstraintSource,
1819        ctx: &Context,
1820    ) -> Result<TypeVar> {
1821        let var = match &gen.inner.inner {
1822            ConstGeneric::Name(name) => {
1823                let ty = &ctx.symtab.type_symbol_by_id(&name);
1824                match &ty.inner {
1825                    TypeSymbol::Declared(_, _) => {
1826                        return Err(Diagnostic::error(
1827                            name,
1828                            "{type_decl_kind} cannot be used in a const generic expression",
1829                        )
1830                        .primary_label("Type in  const generic")
1831                        .secondary_label(ty, "{name} is declared here"))
1832                    }
1833                    TypeSymbol::GenericArg { .. } | TypeSymbol::GenericMeta(MetaType::Type) => {
1834                        return Err(Diagnostic::error(
1835                            name,
1836                            "Generic types cannot be used in const generic expressions",
1837                        )
1838                        .primary_label("Type in  const generic")
1839                        .secondary_label(ty, "{name} is declared here")
1840                        .span_suggest_insert_before(
1841                            "Consider making this a value",
1842                            ty.loc(),
1843                            "#uint ",
1844                        ))
1845                    }
1846                    TypeSymbol::GenericMeta(MetaType::Number) => {
1847                        self.new_generic_tlnumber(gen.loc())
1848                    }
1849                    TypeSymbol::GenericMeta(MetaType::Int) => self.new_generic_tlint(gen.loc()),
1850                    TypeSymbol::GenericMeta(MetaType::Uint) => self.new_generic_tluint(gen.loc()),
1851                    TypeSymbol::GenericMeta(MetaType::Bool) => self.new_generic_tlbool(gen.loc()),
1852                    TypeSymbol::GenericMeta(MetaType::Any) => {
1853                        diag_bail!(gen, "Found any meta type")
1854                    }
1855                    TypeSymbol::Alias(_) => {
1856                        return Err(Diagnostic::error(
1857                            gen,
1858                            "Aliases are not currently supported in const generics",
1859                        )
1860                        .secondary_label(ty, "Alias defined here"))
1861                    }
1862                }
1863            }
1864            ConstGeneric::Const(_)
1865            | ConstGeneric::Add(_, _)
1866            | ConstGeneric::Sub(_, _)
1867            | ConstGeneric::Mul(_, _)
1868            | ConstGeneric::Div(_, _)
1869            | ConstGeneric::Mod(_, _)
1870            | ConstGeneric::UintBitsToFit(_) => self.new_generic_tlnumber(gen.loc()),
1871            ConstGeneric::Eq(_, _) | ConstGeneric::NotEq(_, _) => {
1872                self.new_generic_tlbool(gen.loc())
1873            }
1874        };
1875        let constraint = self.visit_const_generic(&gen.inner.inner, generic_list_token)?;
1876        self.add_equation(TypedExpression::Id(gen.id), var.clone());
1877        self.add_constraint(var.clone(), constraint, gen.loc(), &var, constraint_source);
1878        Ok(var)
1879    }
1880
1881    #[trace_typechecker]
1882    pub fn visit_const_generic(
1883        &self,
1884        constraint: &ConstGeneric,
1885        generic_list: &GenericListToken,
1886    ) -> Result<ConstraintExpr> {
1887        let wrap = |lhs,
1888                    rhs,
1889                    wrapper: fn(Box<ConstraintExpr>, Box<ConstraintExpr>) -> ConstraintExpr|
1890         -> Result<_> {
1891            Ok(wrapper(
1892                Box::new(self.visit_const_generic(lhs, generic_list)?),
1893                Box::new(self.visit_const_generic(rhs, generic_list)?),
1894            ))
1895        };
1896        let constraint = match constraint {
1897            ConstGeneric::Name(n) => {
1898                let var = self.get_generic_list(generic_list).get(n).ok_or_else(|| {
1899                    Diagnostic::bug(n, "Found non-generic argument in where clause")
1900                })?;
1901                ConstraintExpr::Var(self.check_var_for_replacement(var.clone()))
1902            }
1903            ConstGeneric::Const(val) => ConstraintExpr::Integer(val.clone()),
1904            ConstGeneric::Add(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Sum)?,
1905            ConstGeneric::Sub(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Difference)?,
1906            ConstGeneric::Mul(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Product)?,
1907            ConstGeneric::Div(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Div)?,
1908            ConstGeneric::Mod(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Mod)?,
1909            ConstGeneric::Eq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Eq)?,
1910            ConstGeneric::NotEq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::NotEq)?,
1911            ConstGeneric::UintBitsToFit(a) => ConstraintExpr::UintBitsToRepresent(Box::new(
1912                self.visit_const_generic(a, generic_list)?,
1913            )),
1914        };
1915        Ok(constraint)
1916    }
1917}
1918
1919// Private helper functions
1920impl TypeState {
1921    fn new_typeid(&mut self) -> u64 {
1922        let result = self.next_typeid;
1923        self.next_typeid += 1;
1924        result
1925    }
1926
1927    fn check_var_for_replacement(&self, var: TypeVar) -> TypeVar {
1928        if let Some(new) = self.replacements.get(&var) {
1929            if new == &var {
1930                panic!("Type var {new} has been replaced by itself");
1931            }
1932            // We need to do this recursively if we have multiple long lived vars
1933            return self.check_var_for_replacement(new.clone());
1934        };
1935        match var {
1936            TypeVar::Known(loc, base, params) => TypeVar::Known(
1937                loc,
1938                base,
1939                params
1940                    .into_iter()
1941                    .map(|p| self.check_var_for_replacement(p))
1942                    .collect(),
1943            ),
1944            TypeVar::Unknown(loc, id, traits, meta) => TypeVar::Unknown(
1945                loc,
1946                id,
1947                TraitList::from_vec(
1948                    traits
1949                        .inner
1950                        .iter()
1951                        .map(|t| {
1952                            TraitReq {
1953                                name: t.inner.name.clone(),
1954                                type_params: t
1955                                    .inner
1956                                    .type_params
1957                                    .iter()
1958                                    .map(|v| self.check_var_for_replacement(v.clone()))
1959                                    .collect(),
1960                            }
1961                            .at_loc(t)
1962                        })
1963                        .collect(),
1964                ),
1965                meta,
1966            ),
1967        }
1968    }
1969
1970    fn check_expr_for_replacement(&self, val: ConstraintExpr) -> ConstraintExpr {
1971        // FIXME: AS this gets more complicated, consider rewriting it to clone `val` and
1972        // then just replacing the inner values, this is error prone when copy pasting
1973        match val {
1974            v @ ConstraintExpr::Integer(_) => v,
1975            v @ ConstraintExpr::Bool(_) => v,
1976            ConstraintExpr::Var(var) => ConstraintExpr::Var(self.check_var_for_replacement(var)),
1977            ConstraintExpr::Sum(lhs, rhs) => ConstraintExpr::Sum(
1978                Box::new(self.check_expr_for_replacement(*lhs)),
1979                Box::new(self.check_expr_for_replacement(*rhs)),
1980            ),
1981            ConstraintExpr::Difference(lhs, rhs) => ConstraintExpr::Difference(
1982                Box::new(self.check_expr_for_replacement(*lhs)),
1983                Box::new(self.check_expr_for_replacement(*rhs)),
1984            ),
1985            ConstraintExpr::Product(lhs, rhs) => ConstraintExpr::Product(
1986                Box::new(self.check_expr_for_replacement(*lhs)),
1987                Box::new(self.check_expr_for_replacement(*rhs)),
1988            ),
1989            ConstraintExpr::Div(lhs, rhs) => ConstraintExpr::Div(
1990                Box::new(self.check_expr_for_replacement(*lhs)),
1991                Box::new(self.check_expr_for_replacement(*rhs)),
1992            ),
1993            ConstraintExpr::Mod(lhs, rhs) => ConstraintExpr::Mod(
1994                Box::new(self.check_expr_for_replacement(*lhs)),
1995                Box::new(self.check_expr_for_replacement(*rhs)),
1996            ),
1997            ConstraintExpr::Eq(lhs, rhs) => ConstraintExpr::Eq(
1998                Box::new(self.check_expr_for_replacement(*lhs)),
1999                Box::new(self.check_expr_for_replacement(*rhs)),
2000            ),
2001            ConstraintExpr::NotEq(lhs, rhs) => ConstraintExpr::NotEq(
2002                Box::new(self.check_expr_for_replacement(*lhs)),
2003                Box::new(self.check_expr_for_replacement(*rhs)),
2004            ),
2005            ConstraintExpr::Sub(inner) => {
2006                ConstraintExpr::Sub(Box::new(self.check_expr_for_replacement(*inner)))
2007            }
2008            ConstraintExpr::UintBitsToRepresent(inner) => ConstraintExpr::UintBitsToRepresent(
2009                Box::new(self.check_expr_for_replacement(*inner)),
2010            ),
2011        }
2012    }
2013
2014    pub fn add_equation(&mut self, expression: TypedExpression, var: TypeVar) {
2015        let var = self.check_var_for_replacement(var);
2016
2017        self.trace_stack.push(TraceStackEntry::AddingEquation(
2018            expression.clone(),
2019            var.clone(),
2020        ));
2021        if let Some(prev) = self.equations.insert(expression.clone(), var.clone()) {
2022            let var = var.clone();
2023            let expr = expression.clone();
2024            println!("{}", format_trace_stack(self));
2025            panic!("Adding equation for {} == {} but a previous eq exists.\n\tIt was previously bound to {}", expr, var, prev)
2026        }
2027    }
2028
2029    fn add_constraint(
2030        &mut self,
2031        lhs: TypeVar,
2032        rhs: ConstraintExpr,
2033        loc: Loc<()>,
2034        inside: &TypeVar,
2035        source: ConstraintSource,
2036    ) {
2037        let replaces = lhs.clone();
2038        let lhs = self.check_var_for_replacement(lhs);
2039        let rhs = self
2040            .check_expr_for_replacement(rhs)
2041            .with_context(&replaces, &inside.clone(), source)
2042            .at_loc(&loc);
2043
2044        self.trace_stack.push(TraceStackEntry::AddingConstraint(
2045            lhs.clone(),
2046            rhs.inner.clone(),
2047        ));
2048
2049        self.constraints.add_int_constraint(lhs, rhs);
2050    }
2051
2052    fn add_requirement(&mut self, requirement: Requirement) {
2053        macro_rules! replace {
2054            ($name:expr) => {
2055                self.check_var_for_replacement($name.inner.clone())
2056                    .at_loc(&$name)
2057            };
2058        }
2059
2060        let replaced = match requirement {
2061            Requirement::HasField {
2062                target_type,
2063                field,
2064                expr,
2065            } => Requirement::HasField {
2066                field,
2067                target_type: replace!(target_type),
2068                expr: replace!(expr),
2069            },
2070            Requirement::HasMethod {
2071                expr_id,
2072                target_type,
2073                trait_list,
2074                method,
2075                expr,
2076                args,
2077                turbofish,
2078                prev_generic_list,
2079                call_kind,
2080            } => Requirement::HasMethod {
2081                expr_id,
2082                target_type: replace!(target_type),
2083                trait_list,
2084                method,
2085                expr: replace!(expr),
2086                args,
2087                turbofish,
2088                prev_generic_list,
2089                call_kind,
2090            },
2091            Requirement::FitsIntLiteral { value, target_type } => Requirement::FitsIntLiteral {
2092                value: match value {
2093                    ConstantInt::Generic(var) => {
2094                        ConstantInt::Generic(replace!(var.clone().nowhere()).inner)
2095                    }
2096                    lit @ ConstantInt::Literal(_) => lit,
2097                },
2098                target_type: replace!(target_type),
2099            },
2100            Requirement::ValidPipelineOffset {
2101                definition_depth,
2102                current_stage,
2103                reference_offset,
2104            } => Requirement::ValidPipelineOffset {
2105                definition_depth: replace!(definition_depth),
2106                current_stage: replace!(current_stage),
2107                reference_offset: replace!(reference_offset),
2108            },
2109            Requirement::RangeIndexInArray { index, size } => Requirement::RangeIndexInArray {
2110                index: replace!(index),
2111                size: replace!(size),
2112            },
2113            Requirement::RangeIndexEndAfterStart { expr, start, end } => {
2114                Requirement::RangeIndexEndAfterStart {
2115                    expr,
2116                    start: replace!(start),
2117                    end: replace!(end),
2118                }
2119            }
2120            Requirement::PositivePipelineDepth { depth } => Requirement::PositivePipelineDepth {
2121                depth: replace!(depth),
2122            },
2123            Requirement::ArrayIndexeeIsNonZero {
2124                index,
2125                array,
2126                array_size,
2127            } => Requirement::ArrayIndexeeIsNonZero {
2128                index,
2129                array: replace!(array),
2130                array_size: replace!(array_size),
2131            },
2132            Requirement::SharedBase(types) => {
2133                Requirement::SharedBase(types.iter().map(|ty| replace!(ty)).collect())
2134            }
2135        };
2136
2137        self.trace_stack
2138            .push(TraceStackEntry::AddRequirement(replaced.clone()));
2139        self.requirements.push(replaced)
2140    }
2141
2142    /// Performs unification but does not update constraints. This is done to avoid
2143    /// updating constraints more often than necessary. Technically, constraints can
2144    /// be updated even less often, but `unify` is a pretty natural point to do so.
2145
2146    fn unify_inner(
2147        &mut self,
2148        e1: &impl HasType,
2149        e2: &impl HasType,
2150        ctx: &Context,
2151    ) -> std::result::Result<TypeVar, UnificationError> {
2152        let v1 = e1
2153            .get_type(self)
2154            .expect("Tried to unify types but the lhs was not found");
2155        let v2 = e2
2156            .get_type(self)
2157            .expect("Tried to unify types but the rhs was not found");
2158
2159        trace!("Unifying {v1} with {v2}");
2160
2161        let v1 = self.check_var_for_replacement(v1);
2162        let v2 = self.check_var_for_replacement(v2);
2163
2164        self.trace_stack
2165            .push(TraceStackEntry::TryingUnify(v1.clone(), v2.clone()));
2166
2167        // Copy the original types so we can properly trace the progress of the
2168        // unification
2169        let v1cpy = v1.clone();
2170        let v2cpy = v2.clone();
2171
2172        macro_rules! err_producer {
2173            () => {{
2174                self.trace_stack
2175                    .push(TraceStackEntry::Message("Produced error".to_string()));
2176                UnificationError::Normal(Tm {
2177                    g: UnificationTrace::new(self.check_var_for_replacement(v1)),
2178                    e: UnificationTrace::new(self.check_var_for_replacement(v2)),
2179                })
2180            }};
2181        }
2182        macro_rules! meta_err_producer {
2183            () => {{
2184                self.trace_stack
2185                    .push(TraceStackEntry::Message("Produced error".to_string()));
2186                UnificationError::MetaMismatch(Tm {
2187                    g: UnificationTrace::new(self.check_var_for_replacement(v1)),
2188                    e: UnificationTrace::new(self.check_var_for_replacement(v2)),
2189                })
2190            }};
2191        }
2192
2193        macro_rules! unify_if {
2194            ($condition:expr, $new_type:expr, $replaced_type:expr) => {
2195                if $condition {
2196                    Ok(($new_type, $replaced_type))
2197                } else {
2198                    Err(err_producer!())
2199                }
2200            };
2201        }
2202
2203        macro_rules! try_with_context {
2204            ($value: expr) => {
2205                try_with_context!($value, v1, v2)
2206            };
2207            ($value: expr, $v1:expr, $v2:expr) => {
2208                match $value {
2209                    Ok(result) => result,
2210                    e => {
2211                        self.trace_stack
2212                            .push(TraceStackEntry::Message("Adding context".to_string()));
2213                        return e.add_context($v1.clone(), $v2.clone());
2214                    }
2215                }
2216            };
2217        }
2218
2219        macro_rules! unify_params_ {
2220            ($p1:expr, $p2:expr) => {
2221                if $p1.len() != $p2.len() {
2222                    return Err(err_producer!());
2223                }
2224
2225                for (t1, t2) in $p1.iter().zip($p2.iter()) {
2226                    try_with_context!(self.unify_inner(t1, t2, ctx));
2227                }
2228            };
2229        }
2230
2231        // Figure out the most general type, and take note if we need to
2232        // do any replacement of the types in the rest of the state
2233        let result = match (&v1, &v2) {
2234            (TypeVar::Known(_, t1, p1), TypeVar::Known(_, t2, p2)) => {
2235                macro_rules! unify_params {
2236                    () => {
2237                        unify_params_!(p1, p2)
2238                    };
2239                }
2240                match (t1, t2) {
2241                    (KnownType::Integer(val1), KnownType::Integer(val2)) => {
2242                        unify_if!(val1 == val2, v1, vec![])
2243                    }
2244                    (KnownType::Named(n1), KnownType::Named(n2)) => {
2245                        match (
2246                            &ctx.symtab.type_symbol_by_id(n1).inner,
2247                            &ctx.symtab.type_symbol_by_id(n2).inner,
2248                        ) {
2249                            (TypeSymbol::Declared(_, _), TypeSymbol::Declared(_, _)) => {
2250                                if n1 != n2 {
2251                                    return Err(err_producer!());
2252                                }
2253
2254                                unify_params!();
2255                                let new_ts1 = ctx.symtab.type_symbol_by_id(n1).inner;
2256                                let new_ts2 = ctx.symtab.type_symbol_by_id(n2).inner;
2257                                let new_v1 = e1
2258                                    .get_type(self)
2259                                    .expect("Tried to unify types but the lhs was not found");
2260                                unify_if!(
2261                                    new_ts1 == new_ts2,
2262                                    self.check_var_for_replacement(new_v1),
2263                                    vec![]
2264                                )
2265                            }
2266                            (TypeSymbol::Declared(_, _), TypeSymbol::GenericArg { traits }) => {
2267                                if !traits.is_empty() {
2268                                    todo!("Implement trait unifictaion");
2269                                }
2270                                Ok((v1, vec![v2]))
2271                            }
2272                            (TypeSymbol::GenericArg { traits }, TypeSymbol::Declared(_, _)) => {
2273                                if !traits.is_empty() {
2274                                    todo!("Implement trait unifictaion");
2275                                }
2276                                Ok((v2, vec![v1]))
2277                            }
2278                            (
2279                                TypeSymbol::GenericArg { traits: ltraits },
2280                                TypeSymbol::GenericArg { traits: rtraits },
2281                            ) => {
2282                                if !ltraits.is_empty() || !rtraits.is_empty() {
2283                                    todo!("Implement trait unifictaion");
2284                                }
2285                                Ok((v1, vec![v2]))
2286                            }
2287                            (TypeSymbol::Declared(_, _), TypeSymbol::GenericMeta(_)) => todo!(),
2288                            (TypeSymbol::GenericArg { traits: _ }, TypeSymbol::GenericMeta(_)) => {
2289                                todo!()
2290                            }
2291                            (TypeSymbol::GenericMeta(_), TypeSymbol::Declared(_, _)) => todo!(),
2292                            (TypeSymbol::GenericMeta(_), TypeSymbol::GenericArg { traits: _ }) => {
2293                                todo!()
2294                            }
2295                            (TypeSymbol::Alias(_), _) | (_, TypeSymbol::Alias(_)) => {
2296                                return Err(UnificationError::Specific(Diagnostic::bug(
2297                                    ().nowhere(),
2298                                    "Encountered a raw type alias during unification",
2299                                )))
2300                            }
2301                            (TypeSymbol::GenericMeta(_), TypeSymbol::GenericMeta(_)) => todo!(),
2302                        }
2303                    }
2304                    (KnownType::Array, KnownType::Array)
2305                    | (KnownType::Tuple, KnownType::Tuple)
2306                    | (KnownType::Wire, KnownType::Wire)
2307                    | (KnownType::Inverted, KnownType::Inverted) => {
2308                        unify_params!();
2309                        // Note, replacements should only occur when a variable goes from Unknown
2310                        // to Known, not when the base is the same. Replacements take care
2311                        // of parameters. Therefore, None is returned here
2312                        Ok((self.check_var_for_replacement(v2), vec![]))
2313                    }
2314                    (_, _) => Err(err_producer!()),
2315                }
2316            }
2317            // Unknown with unknown requires merging traits
2318            (
2319                TypeVar::Unknown(loc1, _, traits1, meta1),
2320                TypeVar::Unknown(loc2, _, traits2, meta2),
2321            ) => {
2322                let new_loc = if meta1.is_more_concrete_than(meta2) {
2323                    loc1
2324                } else {
2325                    loc2
2326                };
2327                let new_t = match unify_meta(meta1, meta2) {
2328                    Some(meta @ MetaType::Any) | Some(meta @ MetaType::Number) => {
2329                        if traits1.inner.is_empty() || traits2.inner.is_empty() {
2330                            return Err(UnificationError::Specific(diag_anyhow!(
2331                                new_loc,
2332                                "Inferred an any meta-type with traits"
2333                            )));
2334                        }
2335                        self.new_generic_with_meta(*loc1, meta)
2336                    }
2337                    Some(MetaType::Type) => {
2338                        let new_trait_names = traits1
2339                            .inner
2340                            .iter()
2341                            .chain(traits2.inner.iter())
2342                            .map(|t| t.name.clone())
2343                            .collect::<BTreeSet<_>>()
2344                            .into_iter()
2345                            .collect::<Vec<_>>();
2346
2347                        let new_traits = new_trait_names
2348                            .iter()
2349                            .map(
2350                                |name| match (traits1.get_trait(name), traits2.get_trait(name)) {
2351                                    (Some(req1), Some(req2)) => {
2352                                        let new_params = req1
2353                                            .inner
2354                                            .type_params
2355                                            .iter()
2356                                            .zip(req2.inner.type_params.iter())
2357                                            .map(|(p1, p2)| self.unify(p1, p2, ctx))
2358                                            .collect::<std::result::Result<_, UnificationError>>(
2359                                            )?;
2360
2361                                        Ok(TraitReq {
2362                                            name: name.clone(),
2363                                            type_params: new_params,
2364                                        }
2365                                        .at_loc(req1))
2366                                    }
2367                                    (Some(t), None) => Ok(t.clone()),
2368                                    (None, Some(t)) => Ok(t.clone()),
2369                                    (None, None) => panic!("Found a trait but neither side has it"),
2370                                },
2371                            )
2372                            .collect::<std::result::Result<Vec<_>, UnificationError>>()?;
2373
2374                        self.new_generic_with_traits(*new_loc, TraitList::from_vec(new_traits))
2375                    }
2376                    Some(MetaType::Int) => self.new_generic_tlint(*new_loc),
2377                    Some(MetaType::Uint) => self.new_generic_tluint(*new_loc),
2378                    Some(MetaType::Bool) => self.new_generic_tlbool(*new_loc),
2379                    None => return Err(meta_err_producer!()),
2380                };
2381                Ok((new_t, vec![v1, v2]))
2382            }
2383            (
2384                other @ TypeVar::Known(loc, base, params),
2385                uk @ TypeVar::Unknown(ukloc, _, traits, meta),
2386            )
2387            | (
2388                uk @ TypeVar::Unknown(ukloc, _, traits, meta),
2389                other @ TypeVar::Known(loc, base, params),
2390            ) => {
2391                let trait_is_expected = match (&v1, &v2) {
2392                    (TypeVar::Known(_, _, _), _) => true,
2393                    _ => false,
2394                };
2395
2396                self.ensure_impls(other, traits, trait_is_expected, ukloc, ctx)?;
2397
2398                let mut new_params = params.clone();
2399                for t in &traits.inner {
2400                    if !t.type_params.is_empty() {
2401                        if t.type_params.len() != params.len() {
2402                            return Err(err_producer!());
2403                        }
2404
2405                        // If we don't cheat a bit, we'll get bad error messages in this case when
2406                        // we unify, for example, `Number<10>` with `uint<9>`. Since we know the
2407                        // outer types match already, we'll create a fake type for the lhs where
2408                        // we preemptively crate uint<T>
2409                        let fake_type = TypeVar::Known(*loc, base.clone(), t.type_params.clone());
2410
2411                        new_params = t
2412                            .type_params
2413                            .iter()
2414                            .zip(new_params.iter())
2415                            .map(|(t1, t2)| {
2416                                Ok(try_with_context!(
2417                                    self.unify_inner(t1, t2, ctx),
2418                                    fake_type,
2419                                    other
2420                                ))
2421                            })
2422                            .collect::<std::result::Result<_, _>>()?;
2423                    }
2424                }
2425
2426                match (base, meta) {
2427                    // Any matches all types
2428                    (_, MetaType::Any)
2429                    // All types are types
2430                    | (KnownType::Named(_), MetaType::Type)
2431                    | (KnownType::Tuple, MetaType::Type)
2432                    | (KnownType::Array, MetaType::Type)
2433                    | (KnownType::Wire, MetaType::Type)
2434                    | (KnownType::Bool(_), MetaType::Bool)
2435                    | (KnownType::Inverted, MetaType::Type)
2436                    // Integers match ints and numbers
2437                    | (KnownType::Integer(_), MetaType::Int)
2438                    | (KnownType::Integer(_), MetaType::Number)
2439                    => {
2440                        let new = TypeVar::Known(*loc, base.clone(), new_params);
2441
2442                        Ok((new, vec![uk.clone()]))
2443                    },
2444                    // Integers match uints but only if the concrete integer is >= 0
2445                    (KnownType::Integer(val), MetaType::Uint)
2446                    => {
2447                        if val < &0.to_bigint() {
2448                            Err(meta_err_producer!())
2449                        } else {
2450                            let new = TypeVar::Known(*loc, base.clone(), new_params);
2451
2452                            Ok((new, vec![uk.clone()]))
2453                        }
2454                    },
2455
2456                    // Integer with type
2457                    (KnownType::Integer(_), MetaType::Type) => Err(meta_err_producer!()),
2458
2459                    // Bools only unify with any or bool
2460                    (_, MetaType::Bool) => Err(meta_err_producer!()),
2461                    (KnownType::Bool(_), _) => Err(meta_err_producer!()),
2462
2463                    // Type with integer
2464                    (KnownType::Named(_), MetaType::Int | MetaType::Number | MetaType::Uint)
2465                    | (KnownType::Tuple, MetaType::Int | MetaType::Uint | MetaType::Number)
2466                    | (KnownType::Array, MetaType::Int | MetaType::Uint | MetaType::Number)
2467                    | (KnownType::Wire, MetaType::Int | MetaType::Uint | MetaType::Number)
2468                    | (KnownType::Inverted, MetaType::Int | MetaType::Uint | MetaType::Number)
2469                    => Err(meta_err_producer!())
2470                }
2471            }
2472        };
2473
2474        let (new_type, replaced_types) = result?;
2475
2476        self.trace_stack.push(TraceStackEntry::Unified(
2477            v1cpy,
2478            v2cpy,
2479            new_type.clone(),
2480            replaced_types.clone(),
2481        ));
2482
2483        for replaced_type in replaced_types {
2484            self.replacements
2485                .insert(replaced_type.clone(), new_type.clone());
2486
2487            for rhs in self.equations.values_mut() {
2488                TypeState::replace_type_var(rhs, &replaced_type, &new_type)
2489            }
2490            for generic_list in &mut self.generic_lists.values_mut() {
2491                for (_, lhs) in generic_list.iter_mut() {
2492                    TypeState::replace_type_var(lhs, &replaced_type, &new_type)
2493                }
2494            }
2495            for traits in &mut self.trait_impls.inner.values_mut() {
2496                for TraitImpl {
2497                    name: _,
2498                    type_params,
2499                    impl_block: _,
2500                } in traits.iter_mut()
2501                {
2502                    for tvar in type_params.iter_mut() {
2503                        TypeState::replace_type_var(tvar, &replaced_type, &new_type);
2504                    }
2505                }
2506            }
2507            for requirement in &mut self.requirements {
2508                requirement.replace_type_var(&replaced_type, &new_type)
2509            }
2510
2511            if let Some(PipelineState {
2512                current_stage_depth,
2513                pipeline_loc: _,
2514                total_depth,
2515            }) = &mut self.pipeline_state
2516            {
2517                TypeState::replace_type_var(current_stage_depth, &replaced_type, &new_type);
2518                TypeState::replace_type_var(&mut total_depth.inner, &replaced_type, &new_type);
2519            }
2520
2521            self.constraints.inner = self
2522                .constraints
2523                .inner
2524                .clone()
2525                .into_iter()
2526                .map(|(mut lhs, mut rhs)| {
2527                    TypeState::replace_type_var(&mut lhs, &replaced_type, &new_type);
2528                    TypeState::replace_type_var_in_constraint_rhs(
2529                        &mut rhs,
2530                        &replaced_type,
2531                        &new_type,
2532                    );
2533
2534                    (lhs, rhs)
2535                })
2536                .collect()
2537        }
2538
2539        Ok(new_type)
2540    }
2541
2542    #[tracing::instrument(level = "trace", skip_all)]
2543    pub fn unify(
2544        &mut self,
2545        e1: &impl HasType,
2546        e2: &impl HasType,
2547        ctx: &Context,
2548    ) -> std::result::Result<TypeVar, UnificationError> {
2549        let new_type = self.unify_inner(e1, e2, ctx)?;
2550
2551        // With replacement done, some of our constraints may have been updated to provide
2552        // more type inference information. Try to do unification of those new constraints too
2553        loop {
2554            trace!("Updating constraints");
2555            let new_info = self.constraints.update_type_level_value_constraints();
2556
2557            if new_info.is_empty() {
2558                break;
2559            }
2560
2561            for constraint in new_info {
2562                trace!(
2563                    "Constraint replaces {} with {:?}",
2564                    constraint.inner.0,
2565                    constraint.inner.1
2566                );
2567
2568                let ((var, replacement), loc) = constraint.split_loc();
2569
2570                self.trace_stack
2571                    .push(TraceStackEntry::InferringFromConstraints(
2572                        var.clone(),
2573                        replacement.val.clone(),
2574                    ));
2575
2576                let var = self.check_var_for_replacement(var);
2577
2578                // NOTE: safe unwrap. We already checked the constraint above
2579                let expected_type = replacement.val;
2580                let result = self.unify_inner(&expected_type.clone().at_loc(&loc), &var, ctx);
2581                let is_meta_error = matches!(result, Err(UnificationError::MetaMismatch { .. }));
2582                match result {
2583                    Ok(_) => {}
2584                    Err(UnificationError::Normal(Tm { mut e, mut g }))
2585                    | Err(UnificationError::MetaMismatch(Tm { mut e, mut g })) => {
2586                        let mut source_lhs = replacement.context.inside.clone();
2587                        Self::replace_type_var(
2588                            &mut source_lhs,
2589                            &replacement.context.replaces,
2590                            e.outer(),
2591                        );
2592                        let mut source_rhs = replacement.context.inside.clone();
2593                        Self::replace_type_var(
2594                            &mut source_rhs,
2595                            &replacement.context.replaces,
2596                            g.outer(),
2597                        );
2598                        e.inside.replace(source_lhs);
2599                        g.inside.replace(source_rhs);
2600                        return Err(UnificationError::FromConstraints {
2601                            got: g,
2602                            expected: e,
2603                            source: replacement.context.source,
2604                            loc,
2605                            is_meta_error,
2606                        });
2607                    }
2608                    Err(
2609                        e @ UnificationError::FromConstraints { .. }
2610                        | e @ UnificationError::Specific { .. }
2611                        | e @ UnificationError::UnsatisfiedTraits { .. },
2612                    ) => return Err(e),
2613                };
2614            }
2615        }
2616
2617        Ok(new_type)
2618    }
2619
2620    fn ensure_impls(
2621        &mut self,
2622        var: &TypeVar,
2623        traits: &TraitList,
2624        trait_is_expected: bool,
2625        trait_list_loc: &Loc<()>,
2626        ctx: &Context,
2627    ) -> std::result::Result<(), UnificationError> {
2628        self.trace_stack.push(TraceStackEntry::EnsuringImpls(
2629            var.clone(),
2630            traits.clone(),
2631            trait_is_expected,
2632        ));
2633
2634        let number = ctx
2635            .symtab
2636            .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
2637            .expect("Did not find number in symtab")
2638            .0;
2639
2640        macro_rules! error_producer {
2641            ($required_traits:expr) => {
2642                if trait_is_expected {
2643                    if $required_traits.inner.len() == 1
2644                        && $required_traits
2645                            .get_trait(&TraitName::Named(number.clone().nowhere()))
2646                            .is_some()
2647                    {
2648                        Err(UnificationError::Normal(Tm {
2649                            e: UnificationTrace::new(
2650                                self.new_generic_with_traits(*trait_list_loc, $required_traits),
2651                            ),
2652                            g: UnificationTrace::new(var.clone()),
2653                        }))
2654                    } else {
2655                        Err(UnificationError::UnsatisfiedTraits {
2656                            var: var.clone(),
2657                            traits: $required_traits.inner,
2658                            target_loc: trait_list_loc.clone(),
2659                        })
2660                    }
2661                } else {
2662                    Err(UnificationError::Normal(Tm {
2663                        e: UnificationTrace::new(var.clone()),
2664                        g: UnificationTrace::new(
2665                            self.new_generic_with_traits(*trait_list_loc, $required_traits),
2666                        ),
2667                    }))
2668                }
2669            };
2670        }
2671
2672        match var {
2673            TypeVar::Known(_, known, _) if known.into_impl_target().is_some() => {
2674                let Some(target) = known.into_impl_target() else {
2675                    unreachable!()
2676                };
2677
2678                let unsatisfied = traits
2679                    .inner
2680                    .iter()
2681                    .filter(|trait_req| {
2682                        // Number is special cased for now because we can't impl traits
2683                        // on generic types
2684                        if trait_req.name.name_loc().is_some_and(|n| n.inner == number) {
2685                            let int_type = &ctx
2686                                .symtab
2687                                .lookup_type_symbol(&Path::from_strs(&["int"]).nowhere())
2688                                .expect("The type int was not in the symtab")
2689                                .0;
2690                            let uint_type = &ctx
2691                                .symtab
2692                                .lookup_type_symbol(&Path::from_strs(&["uint"]).nowhere())
2693                                .expect("The type uint was not in the symtab")
2694                                .0;
2695
2696                            !(target == ImplTarget::Named(int_type.clone())
2697                                || target == ImplTarget::Named(uint_type.clone()))
2698                        } else {
2699                            if let Some(impld) = self.trait_impls.inner.get(&target) {
2700                                !impld.iter().any(|trait_impl| {
2701                                    trait_impl.name == trait_req.name
2702                                        && trait_impl.type_params == trait_req.type_params
2703                                })
2704                            } else {
2705                                true
2706                            }
2707                        }
2708                    })
2709                    .cloned()
2710                    .collect::<Vec<_>>();
2711
2712                if unsatisfied.is_empty() {
2713                    Ok(())
2714                } else {
2715                    error_producer!(TraitList::from_vec(unsatisfied.clone()))
2716                }
2717            }
2718            TypeVar::Unknown(_, _, _, _) => {
2719                panic!("running ensure_impls on an unknown type")
2720            }
2721            _ => {
2722                if traits.inner.is_empty() {
2723                    Ok(())
2724                } else {
2725                    error_producer!(traits.clone())
2726                }
2727            }
2728        }
2729    }
2730
2731    fn replace_type_var(in_var: &mut TypeVar, from: &TypeVar, replacement: &TypeVar) {
2732        // First, do recursive replacement
2733        match in_var {
2734            TypeVar::Known(_, _, params) => {
2735                for param in params {
2736                    Self::replace_type_var(param, from, replacement)
2737                }
2738            }
2739            TypeVar::Unknown(_, _, traits, _) => {
2740                for t in traits.inner.iter_mut() {
2741                    for param in t.type_params.iter_mut() {
2742                        Self::replace_type_var(param, from, replacement)
2743                    }
2744                }
2745            }
2746        }
2747
2748        let is_same = match (&in_var, &from) {
2749            // Traits do not matter for comparison
2750            (TypeVar::Unknown(_, id1, _, _), TypeVar::Unknown(_, id2, _, _)) => id1 == id2,
2751            (l, r) => l == r,
2752        };
2753
2754        if is_same {
2755            *in_var = replacement.clone();
2756        }
2757    }
2758
2759    fn replace_type_var_in_constraint_expr(
2760        in_constraint: &mut ConstraintExpr,
2761        from: &TypeVar,
2762        replacement: &TypeVar,
2763    ) {
2764        match in_constraint {
2765            ConstraintExpr::Integer(_) => {}
2766            ConstraintExpr::Bool(_) => {}
2767            ConstraintExpr::Var(v) => {
2768                Self::replace_type_var(v, from, replacement);
2769
2770                match v {
2771                    TypeVar::Known(_, KnownType::Integer(val), _) => {
2772                        *in_constraint = ConstraintExpr::Integer(val.clone())
2773                    }
2774                    _ => {}
2775                }
2776            }
2777            ConstraintExpr::Sum(lhs, rhs)
2778            | ConstraintExpr::Mod(lhs, rhs)
2779            | ConstraintExpr::Div(lhs, rhs)
2780            | ConstraintExpr::Difference(lhs, rhs)
2781            | ConstraintExpr::Product(lhs, rhs)
2782            | ConstraintExpr::Eq(lhs, rhs)
2783            | ConstraintExpr::NotEq(lhs, rhs) => {
2784                Self::replace_type_var_in_constraint_expr(lhs, from, replacement);
2785                Self::replace_type_var_in_constraint_expr(rhs, from, replacement);
2786            }
2787            ConstraintExpr::Sub(i) | ConstraintExpr::UintBitsToRepresent(i) => {
2788                Self::replace_type_var_in_constraint_expr(i, from, replacement);
2789            }
2790        }
2791    }
2792
2793    fn replace_type_var_in_constraint_rhs(
2794        in_constraint: &mut ConstraintRhs,
2795        from: &TypeVar,
2796        replacement: &TypeVar,
2797    ) {
2798        Self::replace_type_var_in_constraint_expr(&mut in_constraint.constraint, from, replacement);
2799        // NOTE: We do not want to replace type variables here as that that removes
2800        // information about where the constraint relates. Instead, this replacement
2801        // is performed when reporting the error
2802        // Self::replace_type_var(&mut in_constraint.from, from, replacement);
2803    }
2804
2805    pub fn unify_expression_generic_error(
2806        &mut self,
2807        expr: &Loc<Expression>,
2808        other: &impl HasType,
2809        ctx: &Context,
2810    ) -> Result<TypeVar> {
2811        self.unify(&expr.inner, other, ctx)
2812            .into_default_diagnostic(expr.loc())
2813    }
2814
2815    pub fn check_requirements(&mut self, ctx: &Context) -> Result<()> {
2816        // Once we are done type checking the rest of the entity, check all requirements
2817        loop {
2818            // Walk through all the requirements, checking each one. If the requirement
2819            // is still undetermined, take note to retain that id, otherwise store the
2820            // replacement to be performed
2821            let (retain, replacements_option): (Vec<_>, Vec<_>) = self
2822                .requirements
2823                .clone()
2824                .iter()
2825                .map(|req| match req.check(self, ctx)? {
2826                    requirements::RequirementResult::NoChange => Ok((true, None)),
2827                    requirements::RequirementResult::Satisfied(replacement) => {
2828                        self.trace_stack
2829                            .push(TraceStackEntry::ResolvedRequirement(req.clone()));
2830                        Ok((false, Some(replacement)))
2831                    }
2832                })
2833                .collect::<Result<Vec<_>>>()?
2834                .into_iter()
2835                .unzip();
2836
2837            let replacements = replacements_option
2838                .into_iter()
2839                .flatten()
2840                .flatten()
2841                .collect::<Vec<_>>();
2842
2843            // Drop all replacements that have now been applied
2844            self.requirements = self
2845                .requirements
2846                .drain(0..)
2847                .zip(retain)
2848                .filter_map(|(req, keep)| if keep { Some(req) } else { None })
2849                .collect();
2850
2851            if replacements.is_empty() {
2852                break;
2853            }
2854
2855            for Replacement { from, to, context } in replacements {
2856                self.unify(&to, &from, ctx)
2857                    .into_diagnostic_or_default(from.loc(), context)?;
2858            }
2859        }
2860
2861        Ok(())
2862    }
2863}
2864
2865impl TypeState {
2866    pub fn print_equations(&self) {
2867        for (lhs, rhs) in &self.equations {
2868            println!(
2869                "{} -> {}",
2870                format!("{lhs}").blue(),
2871                format!("{rhs:?}").green()
2872            )
2873        }
2874
2875        println!("\nReplacments:");
2876
2877        for (lhs, rhs) in self.replacements.iter().sorted() {
2878            println!(
2879                "{} -> {}",
2880                format!("{lhs:?}").blue(),
2881                format!("{rhs:?}").green()
2882            )
2883        }
2884
2885        println!("\n Requirements:");
2886
2887        for requirement in &self.requirements {
2888            println!("{:?}", requirement)
2889        }
2890
2891        println!()
2892    }
2893}
2894
2895pub trait HasType: std::fmt::Debug {
2896    fn get_type(&self, state: &TypeState) -> Result<TypeVar>;
2897}
2898
2899impl HasType for TypeVar {
2900    fn get_type(&self, state: &TypeState) -> Result<TypeVar> {
2901        Ok(state.check_var_for_replacement(self.clone()))
2902    }
2903}
2904impl HasType for Loc<TypeVar> {
2905    fn get_type(&self, state: &TypeState) -> Result<TypeVar> {
2906        self.inner.get_type(state)
2907    }
2908}
2909impl HasType for TypedExpression {
2910    fn get_type(&self, state: &TypeState) -> Result<TypeVar> {
2911        state.type_of(self)
2912    }
2913}
2914impl HasType for Expression {
2915    fn get_type(&self, state: &TypeState) -> Result<TypeVar> {
2916        state.type_of(&TypedExpression::Id(self.id))
2917    }
2918}
2919impl HasType for Loc<Expression> {
2920    fn get_type(&self, state: &TypeState) -> Result<TypeVar> {
2921        state.type_of(&TypedExpression::Id(self.inner.id))
2922    }
2923}
2924impl HasType for Pattern {
2925    fn get_type(&self, state: &TypeState) -> Result<TypeVar> {
2926        state.type_of(&TypedExpression::Id(self.id))
2927    }
2928}
2929impl HasType for Loc<Pattern> {
2930    fn get_type(&self, state: &TypeState) -> Result<TypeVar> {
2931        state.type_of(&TypedExpression::Id(self.inner.id))
2932    }
2933}
2934impl HasType for Loc<KnownType> {
2935    fn get_type(&self, _state: &TypeState) -> Result<TypeVar> {
2936        Ok(TypeVar::Known(self.loc(), self.inner.clone(), vec![]))
2937    }
2938}
2939impl HasType for NameID {
2940    fn get_type(&self, state: &TypeState) -> Result<TypeVar> {
2941        state.type_of(&TypedExpression::Name(self.clone()))
2942    }
2943}
2944
2945/// Mapping between names and concrete type used for lookup, without being
2946/// able to do more type inference
2947/// Required because we can't serde the whole TypeState
2948#[derive(Serialize, Deserialize)]
2949pub struct TypeMap {
2950    equations: TypeEquations,
2951}
2952
2953impl TypeMap {
2954    pub fn type_of(&self, thing: &TypedExpression) -> Option<&TypeVar> {
2955        self.equations.get(thing)
2956    }
2957}
2958
2959impl From<TypeState> for TypeMap {
2960    fn from(val: TypeState) -> Self {
2961        TypeMap {
2962            equations: val.equations,
2963        }
2964    }
2965}