spade_typeinference/
lib.rs

// This algorithm is based on the excellent lecture here
// https://www.youtube.com/watch?v=xJXcZp2vgLs
//
// The basic idea is to go through every node in the HIR tree, and for every typed thing,
// add an equation indicating a constraint on that thing. This can only be done once
// and should be done by the visitor for that node. The visitor should then unify
// types according to the rules of the node.
8
9use std::cell::RefCell;
10use std::cmp::PartialEq;
11use std::collections::{BTreeMap, BTreeSet, HashMap};
12
13use colored::Colorize;
14use fixed_types::{t_int, t_uint};
15use hir::expression::CallKind;
16use hir::{
17    param_util, Binding, ConstGeneric, Parameter, PipelineRegMarkerExtra, TypeExpression, TypeSpec,
18    UnitHead, UnitKind, WalTrace, WhereClause,
19};
20use itertools::{Either, Itertools};
21use method_resolution::{FunctionLikeName, IntoImplTarget};
22use num::{BigInt, BigUint, Zero};
23use replacement::ReplacementStack;
24use serde::{Deserialize, Serialize};
25use spade_common::id_tracker::{ExprID, ImplID};
26use spade_common::num_ext::InfallibleToBigInt;
27use spade_diagnostics::diag_list::{DiagList, ResultExt};
28use spade_diagnostics::{diag_anyhow, diag_bail, Diagnostic};
29use spade_macros::trace_typechecker;
30use spade_types::meta_types::{unify_meta, MetaType};
31use trace_stack::TraceStack;
32use tracing::{info, trace};
33
34use spade_common::location_info::{Loc, WithLocation};
35use spade_common::name::{Identifier, NameID, Path};
36use spade_hir::param_util::{match_args_with_params, Argument};
37use spade_hir::symbol_table::{Patternable, PatternableKind, SymbolTable, TypeSymbol};
38use spade_hir::{self as hir, ConstGenericWithId, ImplTarget};
39use spade_hir::{
40    ArgumentList, Block, ExprKind, Expression, ItemList, Pattern, PatternArgument, Register,
41    Statement, TraitName, TraitSpec, TypeParam, Unit,
42};
43use spade_types::KnownType;
44
45use constraints::{
46    bits_to_store, ce_int, ce_var, ConstraintExpr, ConstraintSource, TypeConstraints,
47};
48use equation::{TemplateTypeVarID, TypeEquations, TypeVar, TypeVarID, TypedExpression};
49use error::{
50    error_pattern_type_mismatch, Result, UnificationError, UnificationErrorExt, UnificationTrace,
51};
52use requirements::{Replacement, Requirement};
53use trace_stack::{format_trace_stack, TraceStackEntry};
54use traits::{TraitList, TraitReq};
55
56use crate::error::TypeMismatch as Tm;
57use crate::requirements::ConstantInt;
58use crate::traits::{TraitImpl, TraitImplList};
59
60mod constraints;
61pub mod dump;
62pub mod equation;
63pub mod error;
64pub mod expression;
65pub mod fixed_types;
66pub mod method_resolution;
67pub mod mir_type_lowering;
68mod replacement;
69mod requirements;
70pub mod testutil;
71pub mod trace_stack;
72pub mod traits;
73
74pub struct Context<'a> {
75    pub symtab: &'a SymbolTable,
76    pub items: &'a ItemList,
77    pub trait_impls: &'a TraitImplList,
78}
79impl<'a> std::fmt::Debug for Context<'a> {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "{{context omitted}}")
82    }
83}
84
// NOTE(allow) This is a debug macro which is not normally used but can come in handy
/// Push a `format!`-style message onto the trace stack of a type state.
///
/// Usage: `add_trace!(self, "unifying {} with {}", a, b)`.
///
/// The entry type is named through `$crate` so that the expansion resolves from any
/// module of this crate, not only from modules that import `TraceStackEntry`.
#[allow(unused_macros)]
macro_rules! add_trace {
    ($state:expr, $($arg:tt)*) => {
        $state
            .trace_stack
            .push($crate::trace_stack::TraceStackEntry::Message(format!($($arg)*)))
    };
}
92
/// Describes what a new generic list (a map from generic parameter names to type vars)
/// is being created for; passed to `create_generic_list`.
#[derive(Debug)]
pub enum GenericListSource<'a> {
    /// For when you just need a new generic list but have no need to refer back
    /// to it in the future
    Anonymous,
    /// For managing the generics of the callable with the specified name when type
    /// checking its body
    Definition(&'a NameID),
    /// For the generics of an impl block, identified by its target and impl id
    ImplBlock {
        target: &'a ImplTarget,
        id: ImplID,
    },
    /// For expressions which instantiate generic items
    Expression(ExprID),
}
108
/// Stored version of GenericListSource. This is the owned key under which a generic list
/// is stored in `TypeState::generic_lists`.
#[derive(Clone, Hash, Eq, PartialEq, Debug, Serialize, Deserialize)]
pub enum GenericListToken {
    /// NOTE(review): presumably the `usize` keeps anonymous lists distinct from each
    /// other — confirm against `create_generic_list`.
    Anonymous(usize),
    Definition(NameID),
    ImplBlock(ImplTarget, ImplID),
    Expression(ExprID),
}
117
/// Everything needed to apply an explicit turbofish when instantiating a generic item.
#[derive(Debug)]
pub struct TurbofishCtx<'a> {
    /// The turbofish arguments as written at the use site
    turbofish: &'a Loc<ArgumentList<TypeExpression>>,
    /// The generic list of the code containing the turbofish. NOTE(review): presumably used
    /// to resolve generic names inside the turbofish arguments — confirm in
    /// `handle_function_like`.
    prev_generic_list: &'a GenericListToken,
    type_ctx: &'a Context<'a>,
}
124
/// Bookkeeping for the pipeline whose body is currently being type checked.
#[derive(Clone, Serialize, Deserialize)]
pub struct PipelineState {
    /// Type var holding the depth of the pipeline stage we are currently in. It is a known
    /// integer 0 when the pipeline is entered.
    current_stage_depth: TypeVarID,
    /// Type var for the depth declared in the pipeline head, located at that declaration
    total_depth: Loc<TypeVarID>,
    /// Location of the whole pipeline, used when reporting a depth mismatch
    pipeline_loc: Loc<()>,
}
131
/// State of the type inference algorithm
#[derive(Clone, Serialize, Deserialize)]
pub struct TypeState {
    /// All types are referred to by their index to allow type vars changing inside
    /// the type state while the types are "out in the wild". The TypeVarID is an index
    /// into this type_vars list which is used to look up the actual type as currently
    /// seen by the type state
    type_vars: Vec<TypeVar>,
    /// This key is used to prevent bugs when multiple type states are mixed. Each TypeVarID
    /// holds the value of the key of the type state which created it, and this is checked
    /// to ensure that type vars are not mixed. The key is initialized randomly on type
    /// state creation
    key: u64,
    /// A type state can also support keys from other sources, this is tracked here.
    /// A state made by `create_child` keeps every key of its parent and adds its own.
    keys: BTreeSet<u64>,

    /// The type var of each typed expression, as looked up by `type_of`/`maybe_type_of`
    equations: TypeEquations,
    // NOTE: This is kind of redundant, we could use TypeVarIDs instead of having dedicated
    // numbers for unknown types.
    next_typeid: RefCell<u64>,
    /// Mapping between generic parameters and type vars, one map per generic list.
    /// The key identifies what the generic list is associated with. For example, for a
    /// call to f<A, B> in the expression with id x, the list stored under
    /// `GenericListToken::Expression(x)` will be {A: <type var>, B: <type var>}.
    /// Managed here because unification must update *all* TypeVars in existence.
    generic_lists: HashMap<GenericListToken, HashMap<NameID, TypeVarID>>,

    constraints: TypeConstraints,

    /// Requirements which must be fulfilled but which do not guide further type inference.
    /// For example, if seeing `let y = x.a` before knowing the type of `x`, a requirement is
    /// added to say "x has field a, and y should be the type of that field"
    #[serde(skip)]
    requirements: Vec<Requirement>,

    replacements: ReplacementStack,

    /// State of the pipeline currently being visited, if any, including the type var
    /// containing the depth of the pipeline stage we are currently in
    pipeline_state: Option<PipelineState>,

    /// Trait implementations known to this state. `visit_unit_with_preprocessing`
    /// overwrites this with a copy of the context's impls.
    pub trait_impls: TraitImplList,

    #[serde(skip)]
    pub trace_stack: TraceStack,

    #[serde(skip)]
    pub diags: DiagList,
}
180
181impl TypeState {
182    /// Create a fresh type state, in most cases, this should be .create_child of an
183    /// existing type state
184    pub fn fresh() -> Self {
185        let key = fastrand::u64(..);
186        Self {
187            type_vars: vec![],
188            key,
189            keys: [key].into_iter().collect(),
190            equations: HashMap::new(),
191            next_typeid: RefCell::new(0),
192            trace_stack: TraceStack::new(),
193            constraints: TypeConstraints::new(),
194            requirements: vec![],
195            replacements: ReplacementStack::new(),
196            generic_lists: HashMap::new(),
197            trait_impls: TraitImplList::new(),
198            pipeline_state: None,
199            diags: DiagList::new(),
200        }
201    }
202
203    pub fn create_child(&self) -> Self {
204        let mut result = self.clone();
205        result.key = fastrand::u64(..);
206        result.keys.insert(result.key);
207        result
208    }
209
210    pub fn add_type_var(&mut self, var: TypeVar) -> TypeVarID {
211        let idx = self.type_vars.len();
212        self.type_vars.push(var);
213        TypeVarID {
214            inner: idx,
215            type_state_key: self.key,
216        }
217    }
218
219    pub fn get_equations(&self) -> &TypeEquations {
220        &self.equations
221    }
222
223    pub fn get_constraints(&self) -> &TypeConstraints {
224        &self.constraints
225    }
226
227    // Get a generic list with a safe unwrap since a token is acquired
228    pub fn get_generic_list<'a>(
229        &'a self,
230        generic_list_token: &'a GenericListToken,
231    ) -> Option<&'a HashMap<NameID, TypeVarID>> {
232        self.generic_lists.get(generic_list_token)
233    }
234
235    #[tracing::instrument(level = "trace", skip_all)]
236    fn hir_type_expr_to_var<'a>(
237        &'a mut self,
238        e: &Loc<hir::TypeExpression>,
239        generic_list_token: &GenericListToken,
240    ) -> Result<TypeVarID> {
241        let id = match &e.inner {
242            hir::TypeExpression::Integer(i) => self.add_type_var(TypeVar::Known(
243                e.loc(),
244                KnownType::Integer(i.clone()),
245                vec![],
246            )),
247            hir::TypeExpression::TypeSpec(spec) => {
248                self.type_var_from_hir(e.loc(), &spec.clone(), generic_list_token)?
249            }
250            hir::TypeExpression::ConstGeneric(g) => {
251                let constraint = self.visit_const_generic(g, generic_list_token)?;
252
253                let tvar = self.new_generic_tlnumber(e.loc());
254                self.add_constraint(
255                    tvar.clone(),
256                    constraint,
257                    g.loc(),
258                    &tvar,
259                    ConstraintSource::Where,
260                );
261
262                tvar
263            }
264        };
265        Ok(id)
266    }
267
268    #[tracing::instrument(level = "trace", skip_all, fields(%hir_type))]
269    pub fn type_var_from_hir<'a>(
270        &'a mut self,
271        loc: Loc<()>,
272        hir_type: &crate::hir::TypeSpec,
273        generic_list_token: &GenericListToken,
274    ) -> Result<TypeVarID> {
275        let generic_list = self.get_generic_list(generic_list_token);
276        let var = match &hir_type {
277            hir::TypeSpec::Declared(base, params) => {
278                let params = params
279                    .iter()
280                    .map(|e| self.hir_type_expr_to_var(e, generic_list_token))
281                    .collect::<Result<Vec<_>>>()?;
282
283                self.add_type_var(TypeVar::Known(
284                    loc,
285                    KnownType::Named(base.inner.clone()),
286                    params,
287                ))
288            }
289            hir::TypeSpec::Generic(name) => match generic_list
290                .ok_or_else(|| diag_anyhow!(loc, "Found no generic list for {name}"))?
291                .get(&name.inner)
292            {
293                Some(t) => t.clone(),
294                None => {
295                    for list_source in self.generic_lists.keys() {
296                        info!("Generic lists exist for {list_source:?}");
297                    }
298                    info!("Current source is {generic_list_token:?}");
299                    panic!("No entry in generic list for {name:?}");
300                }
301            },
302            hir::TypeSpec::Tuple(inner) => {
303                let inner = inner
304                    .iter()
305                    .map(|t| self.type_var_from_hir(loc, t, generic_list_token))
306                    .collect::<Result<_>>()?;
307                self.add_type_var(TypeVar::tuple(loc, inner))
308            }
309            hir::TypeSpec::Array { inner, size } => {
310                let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
311                let size = self.hir_type_expr_to_var(size, generic_list_token)?;
312
313                self.add_type_var(TypeVar::array(loc, inner, size))
314            }
315            hir::TypeSpec::Wire(inner) => {
316                let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
317                self.add_type_var(TypeVar::wire(loc, inner))
318            }
319            hir::TypeSpec::Inverted(inner) => {
320                let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
321                self.add_type_var(TypeVar::inverted(loc, inner))
322            }
323            hir::TypeSpec::Wildcard(_) => self.new_generic_any(),
324            hir::TypeSpec::TraitSelf(_) => {
325                diag_bail!(
326                    loc,
327                    "Trying to convert TraitSelf to type inference type var"
328                )
329            }
330        };
331
332        Ok(var)
333    }
334
335    /// Returns the type of the expression with the specified id. Error if no equation
336    /// for the specified expression exists
337    pub fn type_of(&self, expr: &TypedExpression) -> TypeVarID {
338        if let Some(t) = self.equations.get(expr) {
339            *t
340        } else {
341            panic!("Tried looking up the type of {expr:?} but it was not found")
342        }
343    }
344
345    pub fn maybe_type_of(&self, expr: &TypedExpression) -> Option<&TypeVarID> {
346        self.equations.get(expr)
347    }
348
349    pub fn new_generic_int(&mut self, loc: Loc<()>, symtab: &SymbolTable) -> TypeVar {
350        TypeVar::Known(loc, t_int(symtab), vec![self.new_generic_tluint(loc)])
351    }
352
353    pub fn new_concrete_int(&mut self, size: BigUint, loc: Loc<()>) -> TypeVarID {
354        TypeVar::Known(loc, KnownType::Integer(size.to_bigint()), vec![]).insert(self)
355    }
356
357    /// Return a new generic int. The first returned value is int<N>, and the second
358    /// value is N
359    pub fn new_split_generic_int(
360        &mut self,
361        loc: Loc<()>,
362        symtab: &SymbolTable,
363    ) -> (TypeVarID, TypeVarID) {
364        let size = self.new_generic_tlint(loc);
365        let full = self.add_type_var(TypeVar::Known(loc, t_int(symtab), vec![size.clone()]));
366        (full, size)
367    }
368
369    pub fn new_split_generic_uint(
370        &mut self,
371        loc: Loc<()>,
372        symtab: &SymbolTable,
373    ) -> (TypeVarID, TypeVarID) {
374        let size = self.new_generic_tluint(loc);
375        let full = self.add_type_var(TypeVar::Known(loc, t_uint(symtab), vec![size.clone()]));
376        (full, size)
377    }
378
379    pub fn new_generic_with_meta(&mut self, loc: Loc<()>, meta: MetaType) -> TypeVarID {
380        let id = self.new_typeid();
381        self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), meta))
382    }
383
384    pub fn new_generic_type(&mut self, loc: Loc<()>) -> TypeVarID {
385        let id = self.new_typeid();
386        self.add_type_var(TypeVar::Unknown(
387            loc,
388            id,
389            TraitList::empty(),
390            MetaType::Type,
391        ))
392    }
393
394    pub fn new_generic_any(&mut self) -> TypeVarID {
395        let id = self.new_typeid();
396        // NOTE: ().nowhere() is fine here, we can never fail to unify with MetaType::Any so
397        // this loc will never be displayed
398        self.add_type_var(TypeVar::Unknown(
399            ().nowhere(),
400            id,
401            TraitList::empty(),
402            MetaType::Any,
403        ))
404    }
405
406    pub fn new_generic_tlbool(&mut self, loc: Loc<()>) -> TypeVarID {
407        let id = self.new_typeid();
408        self.add_type_var(TypeVar::Unknown(
409            loc,
410            id,
411            TraitList::empty(),
412            MetaType::Bool,
413        ))
414    }
415
416    pub fn new_generic_tluint(&mut self, loc: Loc<()>) -> TypeVarID {
417        let id = self.new_typeid();
418        self.add_type_var(TypeVar::Unknown(
419            loc,
420            id,
421            TraitList::empty(),
422            MetaType::Uint,
423        ))
424    }
425
426    pub fn new_generic_tlint(&mut self, loc: Loc<()>) -> TypeVarID {
427        let id = self.new_typeid();
428        self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Int))
429    }
430
431    pub fn new_generic_tlnumber(&mut self, loc: Loc<()>) -> TypeVarID {
432        let id = self.new_typeid();
433        self.add_type_var(TypeVar::Unknown(
434            loc,
435            id,
436            TraitList::empty(),
437            MetaType::Number,
438        ))
439    }
440
441    pub fn new_generic_number(&mut self, loc: Loc<()>, ctx: &Context) -> (TypeVarID, TypeVarID) {
442        let number = ctx
443            .symtab
444            .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
445            .expect("Did not find number in symtab")
446            .0;
447        let id = self.new_typeid();
448        let size = self.new_generic_tluint(loc);
449        let t = TraitReq {
450            name: TraitName::Named(number.nowhere()),
451            type_params: vec![size.clone()],
452        }
453        .nowhere();
454        (
455            self.add_type_var(TypeVar::Unknown(
456                loc,
457                id,
458                TraitList::from_vec(vec![t]),
459                MetaType::Type,
460            )),
461            size,
462        )
463    }
464
465    pub fn new_generic_with_traits(&mut self, loc: Loc<()>, traits: TraitList) -> TypeVarID {
466        let id = self.new_typeid();
467        self.add_type_var(TypeVar::Unknown(loc, id, traits, MetaType::Type))
468    }
469
470    /// Returns the current pipeline state. `access_loc` is *not* used to specify where to get the
471    /// state from, it is only used for the Diagnostic::bug that gets emitted if there is no
472    /// pipeline state
473    pub fn get_pipeline_state<T>(&self, access_loc: &Loc<T>) -> Result<&PipelineState> {
474        self.pipeline_state
475            .as_ref()
476            .ok_or_else(|| diag_anyhow!(access_loc, "Expected to have a pipeline state"))
477    }
478
    /// Visit a unit but before doing any preprocessing run the `pp` function.
    ///
    /// Steps, in order:
    /// 1. Copy the trait impls from `ctx` and create the generic list for the unit's
    ///    definition.
    /// 2. Run `pp` with that generic list.
    /// 3. Add equations for the inputs. For pipelines, also:
    ///    * set up the pipeline state;
    ///    * require a positive depth;
    ///    * require the first input to be a clock;
    ///    * check requirements early.
    /// 4. Visit the body and unify its type with the declared output type, or with unit if
    ///    no output type is declared.
    /// 5. If a pipeline state is present, unify the declared depth with the stage depth
    ///    reached. Then check all requirements and clear the pipeline state.
    ///
    /// # Errors
    /// Returns the first error produced by any of these steps. `visit_expression` is called
    /// without `?`. NOTE(review): presumably it reports errors elsewhere (e.g. via
    /// `self.diags`) — confirm.
    pub fn visit_unit_with_preprocessing(
        &mut self,
        entity: &Loc<Unit>,
        pp: impl Fn(&mut TypeState, &Loc<Unit>, &GenericListToken, &Context) -> Result<()>,
        ctx: &Context,
    ) -> Result<()> {
        self.trait_impls = ctx.trait_impls.clone();

        let generic_list = self.create_generic_list(
            GenericListSource::Definition(&entity.name.name_id().inner),
            &entity.head.unit_type_params,
            &entity.head.scope_type_params,
            None,
            // NOTE: I'm not 100% sure we need to pass these here, the information
            // is probably redundant
            &entity.head.where_clauses,
        )?;

        pp(self, entity, &generic_list, ctx)?;

        // Add equations for the inputs
        for (name, t) in &entity.inputs {
            let tvar = self.type_var_from_hir(t.loc(), t, &generic_list)?;
            self.add_equation(TypedExpression::Name(name.inner.clone()), tvar)
        }

        if let UnitKind::Pipeline {
            depth,
            depth_typeexpr_id,
        } = &entity.head.unit_kind.inner
        {
            let depth_var = self.hir_type_expr_to_var(depth, &generic_list)?;
            self.add_equation(TypedExpression::Id(*depth_typeexpr_id), depth_var.clone());
            // Stage counting starts at a known 0; `total_depth` is the depth declared in
            // the pipeline head.
            self.pipeline_state = Some(PipelineState {
                current_stage_depth: self.add_type_var(TypeVar::Known(
                    entity.head.unit_kind.loc(),
                    KnownType::Integer(BigInt::zero()),
                    vec![],
                )),
                pipeline_loc: entity.loc(),
                total_depth: depth_var.clone().at_loc(depth),
            });
            self.add_requirement(Requirement::PositivePipelineDepth {
                depth: depth_var.at_loc(depth),
            });
            // NOTE(review): `inputs[0]` is indexed directly, so this assumes every pipeline
            // has at least one input (the clock). Presumably enforced before type
            // inference — TODO confirm.
            TypedExpression::Name(entity.inputs[0].0.clone().inner)
                .unify_with(&self.t_clock(entity.head.unit_kind.loc(), ctx.symtab), self)
                .commit(self, ctx)
                .into_diagnostic(
                    entity.inputs[0].1.loc(),
                    |diag,
                     Tm {
                         g: got,
                         e: _expected,
                     }| {
                        diag.message(format!(
                            "First pipeline argument must be a clock. Got {}",
                            got.display(self)
                        ))
                        .primary_label("expected clock")
                    },
                    self,
                )?;
            // In order to catch negative depth early when they are specified as literals,
            // we'll instantly check requirements here
            self.check_requirements(false, ctx)?;
        }

        self.visit_expression(&entity.body, ctx, &generic_list);

        // Ensure that the output type matches what the user specified, and unit otherwise
        if let Some(output_type) = &entity.head.output_type {
            let tvar = self.type_var_from_hir(output_type.loc(), output_type, &generic_list)?;

            self.trace_stack.push(TraceStackEntry::Message(format!(
                "Unifying with output type {}",
                tvar.debug_resolve(self)
            )));
            self.unify(&TypedExpression::Id(entity.body.inner.id), &tvar, ctx)
                .into_diagnostic_no_expected_source(
                    &entity.body,
                    |diag,
                     Tm {
                         g: got,
                         e: expected,
                     }| {
                        let expected = expected.display(self);
                        let got = got.display(self);
                        // FIXME: Might want to check if there is no return value in the block
                        // and in that case suggest adding it.
                        diag.message(format!(
                            "Output type mismatch. Expected {expected}, got {got}"
                        ))
                        .primary_label(format!("Found type {got}"))
                        .secondary_label(output_type, format!("{expected} type specified here"))
                    },
                    self,
                )?;
        } else {
            // No output type, so unify with the unit type.
            TypedExpression::Id(entity.body.inner.id)
                .unify_with(&self.add_type_var(TypeVar::unit(entity.head.name.loc())), self)
                .commit(self, ctx)
                .into_diagnostic_no_expected_source(entity.body.loc(), |diag, Tm{g: got, e: _expected}| {
                    diag.message("Output type mismatch")
                        .primary_label(format!("Found type {got}", got = got.display(self)))
                        .note(format!(
                            "The {} does not specify a return type.\nAdd a return type, or remove the return value.",
                            entity.head.unit_kind.name()
                        ))
                }, self)?;
        }

        // In the other `unify` calls in this function, the first argument is reported as
        // the "got" type (`g`). Here that is the declared depth, and the stage depth reached
        // is `e`, which is why the message prints "Expected {g} got {e}".
        // NOTE(review): this check runs for *any* unit whose `pipeline_state` is `Some`,
        // not only for pipelines. A depth leaked by an earlier call that returned early
        // would be checked against a later non-pipeline unit visited with the same
        // `TypeState` — confirm states are never reused after an error.
        if let Some(PipelineState {
            current_stage_depth,
            pipeline_loc,
            total_depth,
        }) = self.pipeline_state.clone()
        {
            self.unify(&total_depth.inner, &current_stage_depth, ctx)
                .into_diagnostic_no_expected_source(
                    pipeline_loc,
                    |diag, tm| {
                        let (e, g) = tm.display_e_g(self);
                        diag.message(format!("Pipeline depth mismatch. Expected {g} got {e}"))
                            .primary_label(format!("Found {e} stages in this pipeline"))
                    },
                    self,
                )?;
        }

        self.check_requirements(true, ctx)?;

        // NOTE: We may accidentally leak a stage depth if this function returns early. However,
        // since we only use stage depths in pipelines where we re-set it when we enter,
        // accidentally leaving a depth in place should not trigger any *new* compiler bugs, but
        // may help track something down
        self.pipeline_state = None;

        Ok(())
    }
621
622    #[trace_typechecker]
623    #[tracing::instrument(level = "trace", skip_all, fields(%entity.name))]
624    pub fn visit_unit(&mut self, entity: &Loc<Unit>, ctx: &Context) -> Result<()> {
625        self.visit_unit_with_preprocessing(entity, |_, _, _, _| Ok(()), ctx)
626    }
627
628    #[trace_typechecker]
629    #[tracing::instrument(level = "trace", skip_all)]
630    fn visit_argument_list(
631        &mut self,
632        args: &Loc<ArgumentList<Expression>>,
633        ctx: &Context,
634        generic_list: &GenericListToken,
635    ) -> Result<()> {
636        for expr in args.expressions() {
637            self.visit_expression(expr, ctx, generic_list);
638        }
639        Ok(())
640    }
641
642    #[trace_typechecker]
643    fn type_check_argument_list(
644        &mut self,
645        args: &[Argument<Expression, TypeSpec>],
646        ctx: &Context,
647        generic_list: &GenericListToken,
648    ) -> Result<()> {
649        for Argument {
650            target,
651            target_type,
652            value,
653            kind,
654        } in args.iter()
655        {
656            let target_type = self.type_var_from_hir(value.loc(), target_type, generic_list)?;
657
658            let loc = match kind {
659                hir::param_util::ArgumentKind::Positional => value.loc(),
660                hir::param_util::ArgumentKind::Named => value.loc(),
661                hir::param_util::ArgumentKind::ShortNamed => target.loc(),
662            };
663
664            self.unify(&value.inner, &target_type, ctx)
665                .into_diagnostic(
666                    loc,
667                    |d, tm| {
668                        let (expected, got) = tm.display_e_g(self);
669                        d.message(format!(
670                            "Argument type mismatch. Expected {expected} got {got}"
671                        ))
672                        .primary_label(format!("expected {expected}"))
673                    },
674                    self,
675                )?;
676        }
677
678        Ok(())
679    }
680
681    #[trace_typechecker]
682    pub fn visit_expression_result(
683        &mut self,
684        expression: &Loc<Expression>,
685        ctx: &Context,
686        generic_list: &GenericListToken,
687        new_type: TypeVarID,
688    ) -> Result<()> {
689        // Recurse down the expression
690        match &expression.inner.kind {
691            ExprKind::Error => {
692                new_type
693                    .unify_with(&self.t_err(expression.loc()), self)
694                    .commit(self, ctx)
695                    .unwrap();
696            }
697            ExprKind::Identifier(_) => self.visit_identifier(expression, ctx)?,
698            ExprKind::TypeLevelInteger(_) => {
699                self.visit_type_level_integer(expression, generic_list, ctx)?
700            }
701            ExprKind::IntLiteral(_, _) => self.visit_int_literal(expression, ctx)?,
702            ExprKind::BoolLiteral(_) => self.visit_bool_literal(expression, ctx)?,
703            ExprKind::BitLiteral(_) => self.visit_bit_literal(expression, ctx)?,
704            ExprKind::TupleLiteral(_) => self.visit_tuple_literal(expression, ctx, generic_list)?,
705            ExprKind::TupleIndex(_, _) => self.visit_tuple_index(expression, ctx, generic_list)?,
706            ExprKind::ArrayLiteral(_) => self.visit_array_literal(expression, ctx, generic_list)?,
707            ExprKind::ArrayShorthandLiteral(_, _) => {
708                self.visit_array_shorthand_literal(expression, ctx, generic_list)?
709            }
710            ExprKind::CreatePorts => self.visit_create_ports(expression, ctx, generic_list)?,
711            ExprKind::FieldAccess(_, _) => {
712                self.visit_field_access(expression, ctx, generic_list)?
713            }
714            ExprKind::MethodCall { .. } => self.visit_method_call(expression, ctx, generic_list)?,
715            ExprKind::Index(_, _) => self.visit_index(expression, ctx, generic_list)?,
716            ExprKind::RangeIndex { .. } => self.visit_range_index(expression, ctx, generic_list)?,
717            ExprKind::Block(_) => self.visit_block_expr(expression, ctx, generic_list)?,
718            ExprKind::If(_, _, _) => self.visit_if(expression, ctx, generic_list)?,
719            ExprKind::Match(_, _) => self.visit_match(expression, ctx, generic_list)?,
720            ExprKind::BinaryOperator(_, _, _) => {
721                self.visit_binary_operator(expression, ctx, generic_list)?
722            }
723            ExprKind::UnaryOperator(_, _) => {
724                self.visit_unary_operator(expression, ctx, generic_list)?
725            }
726            ExprKind::Call {
727                kind,
728                callee,
729                args,
730                turbofish,
731            } => {
732                let head = ctx.symtab.unit_by_id(&callee.inner);
733
734                self.handle_function_like(
735                    expression.map_ref(|e| e.id),
736                    &expression.get_type(self),
737                    &FunctionLikeName::Free(callee.inner.clone()),
738                    &head,
739                    kind,
740                    args,
741                    ctx,
742                    true,
743                    false,
744                    turbofish.as_ref().map(|turbofish| TurbofishCtx {
745                        turbofish,
746                        prev_generic_list: generic_list,
747                        type_ctx: ctx,
748                    }),
749                    generic_list,
750                )?;
751            }
752            ExprKind::PipelineRef { .. } => {
753                self.visit_pipeline_ref(expression, generic_list, ctx)?;
754            }
755            ExprKind::StageReady | ExprKind::StageValid => {
756                expression
757                    .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
758                    .commit(self, ctx)
759                    .into_default_diagnostic(expression, self)?;
760            }
761
762            ExprKind::TypeLevelIf(cond, on_true, on_false) => {
763                let cond_var = self.visit_const_generic_with_id(
764                    cond,
765                    generic_list,
766                    ConstraintSource::TypeLevelIf,
767                    ctx,
768                )?;
769                let t_bool = self.new_generic_tlbool(cond.loc());
770                self.unify(&cond_var, &t_bool, ctx).into_diagnostic(
771                    cond,
772                    |diag, tm| {
773                        let (_e, g) = tm.display_e_g(self);
774                        diag.message(format!("gen if conditions must be #bool, got {g}"))
775                    },
776                    self,
777                )?;
778
779                self.visit_expression(on_true, ctx, generic_list);
780                self.visit_expression(on_false, ctx, generic_list);
781
782                self.unify_expression_generic_error(expression, on_true.as_ref(), ctx)?;
783                self.unify_expression_generic_error(expression, on_false.as_ref(), ctx)?;
784            }
785            ExprKind::LambdaDef {
786                arguments,
787                body,
788                lambda_type,
789                lambda_type_params,
790                captured_generic_params,
791                lambda_unit: _,
792            } => {
793                for arg in arguments {
794                    self.visit_pattern(arg, ctx, generic_list)?;
795                }
796
797                self.visit_expression(body, ctx, generic_list);
798
799                let lambda_params = arguments
800                    .iter()
801                    .map(|arg| arg.get_type(self))
802                    .chain(vec![body.get_type(self)])
803                    .chain(
804                        captured_generic_params
805                            .iter()
806                            .map(|cap| {
807                                let t = self
808                                    .get_generic_list(generic_list)
809                                    .ok_or_else(|| {
810                                        diag_anyhow!(
811                                            expression,
812                                            "Found a captured generic but no generic list"
813                                        )
814                                    })?
815                                    .get(&cap.name_in_body)
816                                    .ok_or_else(|| {
817                                        diag_anyhow!(
818                                            &cap.name_in_body,
819                                            "Did not find an entry for {} in lambda generic list",
820                                            cap.name_in_body
821                                        )
822                                    });
823                                Ok(t?.clone())
824                            })
825                            .collect::<Result<Vec<_>>>()?
826                            .into_iter(),
827                    )
828                    .collect::<Vec<_>>();
829
830                let self_type = TypeVar::Known(
831                    expression.loc(),
832                    KnownType::Named(lambda_type.clone()),
833                    lambda_params.clone(),
834                );
835
836                let unit_generic_list = self.create_generic_list(
837                    GenericListSource::Expression(expression.id),
838                    &lambda_type_params,
839                    &[],
840                    None,
841                    &[],
842                )?;
843
844                for (p, tp) in lambda_params.iter().zip(lambda_type_params) {
845                    let gl = self.get_generic_list(&unit_generic_list).unwrap();
846                    // Safe unwrap, we're just unifying unknowns with unknowns
847                    p.unify_with(
848                        gl.get(&tp.name_id).ok_or_else(|| {
849                            diag_anyhow!(
850                                expression,
851                                "Lambda unit list did not contain {}",
852                                tp.name_id
853                            )
854                        })?,
855                        self,
856                    )
857                    .commit(self, ctx)
858                    .into_default_diagnostic(expression, self)?;
859                }
860                expression
861                    .unify_with(&self.add_type_var(self_type), self)
862                    .commit(self, ctx)
863                    .into_default_diagnostic(expression, self)?;
864            }
865            ExprKind::StaticUnreachable(_) => {}
866            ExprKind::Null => {}
867        }
868        Ok(())
869    }
870
871    #[tracing::instrument(level = "trace", skip_all)]
872    pub fn visit_expression(
873        &mut self,
874        expression: &Loc<Expression>,
875        ctx: &Context,
876        generic_list: &GenericListToken,
877    ) {
878        let new_type = self.new_generic_type(expression.loc());
879        self.add_equation(TypedExpression::Id(expression.inner.id), new_type);
880
881        match self.visit_expression_result(expression, ctx, generic_list, new_type) {
882            Ok(_) => {}
883            Err(e) => {
884                new_type
885                    .unify_with(&self.t_err(expression.loc()), self)
886                    .commit(self, ctx)
887                    .unwrap();
888
889                self.diags.errors.push(e);
890            }
891        }
892    }
893
    // Common handler for entities, functions and pipelines.
    //
    // Type checks a call to `name`, whose signature is `head`:
    // 1. creates a fresh generic list for the callee's type parameters, stored under the
    //    call expression's id and seeded from `turbofish` if one was given,
    // 2. if both the definition and the call are pipelines, unifies the depth written at
    //    the call site with the depth of the definition,
    // 3. visits the argument expressions when `visit_args` is set,
    // 4. applies extra typing rules for a fixed set of `std` functions (see below),
    // 5. type checks the arguments against the callee's inputs, and
    // 6. unifies `expression_type` with the callee's output type, or with an empty tuple
    //    when the head declares no output type.
    //
    // `generic_list` is the generic list of the *calling* context. It is used to resolve
    // the call-site pipeline depth and when visiting the arguments.
    #[tracing::instrument(level = "trace", skip_all, fields(%name))]
    #[trace_typechecker]
    fn handle_function_like(
        &mut self,
        expression_id: Loc<ExprID>,
        expression_type: &TypeVarID,
        name: &FunctionLikeName,
        head: &Loc<UnitHead>,
        call_kind: &CallKind,
        args: &Loc<ArgumentList<Expression>>,
        ctx: &Context,
        // Whether or not to visit the argument expressions passed to the function here. This
        // should not be done if the expressions have been visited before, for example, when
        // handling methods
        visit_args: bool,
        // If we are calling a method, we have an implicit self argument which means
        // that any error reporting number of arguments should be reduced by one
        is_method: bool,
        turbofish: Option<TurbofishCtx>,
        generic_list: &GenericListToken,
    ) -> Result<()> {
        // Add new symbols for all the type parameters
        let unit_generic_list = self.create_generic_list(
            GenericListSource::Expression(expression_id.inner),
            &head.unit_type_params,
            &head.scope_type_params,
            turbofish,
            &head.where_clauses,
        )?;

        // Only a pipeline definition called with pipeline syntax has depths to compare;
        // every other combination of unit kind and call kind is not checked here.
        match (&head.unit_kind.inner, call_kind) {
            (
                UnitKind::Pipeline {
                    depth: udepth,
                    depth_typeexpr_id: _,
                },
                CallKind::Pipeline {
                    inst_loc: _,
                    depth: cdepth,
                    depth_typeexpr_id: cdepth_typeexpr_id,
                },
            ) => {
                // The definition's depth is resolved in the callee's fresh generic list,
                // the call-site depth in the caller's generic list.
                let definition_depth = self.hir_type_expr_to_var(udepth, &unit_generic_list)?;
                let call_depth = self.hir_type_expr_to_var(cdepth, generic_list)?;

                // NOTE: We're not adding udepth_typeexpr_id here as that would break
                // in the future if we try to do recursion. We will also never need to look
                // up the depth in the definition
                self.add_equation(TypedExpression::Id(*cdepth_typeexpr_id), call_depth.clone());

                self.unify(&call_depth, &definition_depth, ctx)
                    .into_diagnostic_no_expected_source(
                        cdepth,
                        |diag, tm| {
                            let (e, g) = tm.display_e_g(self);
                            diag.message("Pipeline depth mismatch")
                                .primary_label(format!("Expected depth {e}, got {g}"))
                                .secondary_label(udepth, format!("{name} has depth {e}"))
                        },
                        self,
                    )?;
            }
            _ => {}
        }

        if visit_args {
            self.visit_argument_list(args, ctx, &generic_list)?;
        }

        let type_params = &head.get_type_params();

        // Special handling of built in functions.
        //
        // For each `[path segments] => handler` entry, the path is looked up in the symbol
        // table and the handler runs if it resolves to the free function `name`. Every
        // entry is checked in turn; a path that fails to resolve is treated as no match.
        macro_rules! handle_special_functions {
            ($([$($path:expr),*] => $handler:expr),*) => {
                $(
                    let path = Path(vec![$(Identifier($path.to_string()).nowhere()),*]).nowhere();
                    if ctx.symtab
                        .try_lookup_final_id(&path, &[])
                        .map(|n| &FunctionLikeName::Free(n) == name)
                        .unwrap_or(false)
                    {
                        $handler
                    };
                )*
            }
        }

        // NOTE: These would be better as a closure, but unfortunately that does not satisfy
        // the borrow checker
        //
        // `generic_arg!(i)` is the type variable that the callee's fresh generic list holds
        // for the i-th entry of `head.get_type_params()`. `type_params[i]` panics if `i`
        // is out of range.
        macro_rules! generic_arg {
            ($idx:expr) => {
                self.get_generic_list(&unit_generic_list)
                    .ok_or_else(|| diag_anyhow!(expression_id, "Found no generic list for call"))?
                    [&type_params[$idx].name_id()]
                    .clone()
            };
        }

        // Pair the call's arguments with the callee's inputs. On failure, the diagnostic
        // additionally points at the unit definition.
        let matched_args =
            match_args_with_params(args, &head.inputs.inner, is_method).map_err(|e| {
                let diag: Diagnostic = e.into();
                diag.secondary_label(
                    head,
                    format!("{kind} defined here", kind = head.unit_kind.name()),
                )
            })?;

        handle_special_functions! {
            ["std", "conv", "concat"] => {
                self.handle_concat(
                    expression_id,
                    generic_arg!(0),
                    generic_arg!(1),
                    generic_arg!(2),
                    &matched_args,
                    ctx
                )?
            },
            ["std", "conv", "trunc"] => {
                self.handle_trunc(
                    expression_id,
                    generic_arg!(0),
                    generic_arg!(1),
                    &matched_args,
                    ctx
                )?
            },
            ["std", "ops", "comb_div"] => {
                self.handle_comb_mod_or_div(
                    generic_arg!(0),
                    &matched_args,
                    ctx
                )?
            },
            ["std", "ops", "comb_mod"] => {
                self.handle_comb_mod_or_div(
                    generic_arg!(0),
                    &matched_args,
                    ctx
                )?
            },
            ["std", "mem", "clocked_memory"]  => {
                let num_elements = generic_arg!(0);
                let addr_size = generic_arg!(2);

                self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
            },
            // NOTE: the last argument of _init does not need special treatment, so
            // we can use the same code path for both for now
            ["std", "mem", "clocked_memory_init"]  => {
                let num_elements = generic_arg!(0);
                let addr_size = generic_arg!(2);

                self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
            },
            // read_memory orders its type parameters differently from clocked_memory:
            // the address width is parameter 0 and the element count parameter 2.
            ["std", "mem", "read_memory"]  => {
                let addr_size = generic_arg!(0);
                let num_elements = generic_arg!(2);

                self.handle_read_memory(num_elements, addr_size, &matched_args, ctx)?
            }
        };

        // Unify the types of the arguments
        self.type_check_argument_list(&matched_args, ctx, &unit_generic_list)?;

        // Units without a declared output type return an empty tuple.
        let return_type = head
            .output_type
            .as_ref()
            .map(|o| self.type_var_from_hir(expression_id.loc(), o, &unit_generic_list))
            .transpose()?
            .unwrap_or_else(|| {
                self.add_type_var(TypeVar::Known(
                    expression_id.loc(),
                    KnownType::Tuple,
                    vec![],
                ))
            });

        self.unify(expression_type, &return_type, ctx)
            .into_default_diagnostic(expression_id.loc(), self)?;

        Ok(())
    }
1079
1080    pub fn handle_concat(
1081        &mut self,
1082        expression_id: Loc<ExprID>,
1083        source_lhs_ty: TypeVarID,
1084        source_rhs_ty: TypeVarID,
1085        source_result_ty: TypeVarID,
1086        args: &[Argument<Expression, TypeSpec>],
1087        ctx: &Context,
1088    ) -> Result<()> {
1089        let (lhs_type, lhs_size) = self.new_generic_number(expression_id.loc(), ctx);
1090        let (rhs_type, rhs_size) = self.new_generic_number(expression_id.loc(), ctx);
1091        let (result_type, result_size) = self.new_generic_number(expression_id.loc(), ctx);
1092        self.unify(&source_lhs_ty, &lhs_type, ctx)
1093            .into_default_diagnostic(args[0].value.loc(), self)?;
1094        self.unify(&source_rhs_ty, &rhs_type, ctx)
1095            .into_default_diagnostic(args[1].value.loc(), self)?;
1096        self.unify(&source_result_ty, &result_type, ctx)
1097            .into_default_diagnostic(expression_id.loc(), self)?;
1098
1099        // Result size is sum of input sizes
1100        self.add_constraint(
1101            result_size.clone(),
1102            ce_var(&lhs_size) + ce_var(&rhs_size),
1103            expression_id.loc(),
1104            &result_size,
1105            ConstraintSource::Concatenation,
1106        );
1107        self.add_constraint(
1108            lhs_size.clone(),
1109            ce_var(&result_size) + -ce_var(&rhs_size),
1110            args[0].value.loc(),
1111            &lhs_size,
1112            ConstraintSource::Concatenation,
1113        );
1114        self.add_constraint(
1115            rhs_size.clone(),
1116            ce_var(&result_size) + -ce_var(&lhs_size),
1117            args[1].value.loc(),
1118            &rhs_size,
1119            ConstraintSource::Concatenation,
1120        );
1121
1122        self.add_requirement(Requirement::SharedBase(vec![
1123            lhs_type.at_loc(args[0].value),
1124            rhs_type.at_loc(args[1].value),
1125            result_type.at_loc(&expression_id.loc()),
1126        ]));
1127        Ok(())
1128    }
1129
1130    pub fn handle_trunc(
1131        &mut self,
1132        expression_id: Loc<ExprID>,
1133        source_in_ty: TypeVarID,
1134        source_result_ty: TypeVarID,
1135        args: &[Argument<Expression, TypeSpec>],
1136        ctx: &Context,
1137    ) -> Result<()> {
1138        let (in_ty, _) = self.new_generic_number(expression_id.loc(), ctx);
1139        let (result_type, _) = self.new_generic_number(expression_id.loc(), ctx);
1140        self.unify(&source_in_ty, &in_ty, ctx)
1141            .into_default_diagnostic(args[0].value.loc(), self)?;
1142        self.unify(&source_result_ty, &result_type, ctx)
1143            .into_default_diagnostic(expression_id.loc(), self)?;
1144
1145        self.add_requirement(Requirement::SharedBase(vec![
1146            in_ty.at_loc(args[0].value),
1147            result_type.at_loc(&expression_id.loc()),
1148        ]));
1149        Ok(())
1150    }
1151
1152    pub fn handle_comb_mod_or_div(
1153        &mut self,
1154        n_ty: TypeVarID,
1155        args: &[Argument<Expression, TypeSpec>],
1156        ctx: &Context,
1157    ) -> Result<()> {
1158        let (num, _) = self.new_generic_number(args[0].value.loc(), ctx);
1159        self.unify(&n_ty, &num, ctx)
1160            .into_default_diagnostic(args[0].value.loc(), self)?;
1161        Ok(())
1162    }
1163
1164    pub fn handle_clocked_memory(
1165        &mut self,
1166        num_elements: TypeVarID,
1167        addr_size_arg: TypeVarID,
1168        args: &[Argument<Expression, TypeSpec>],
1169        ctx: &Context,
1170    ) -> Result<()> {
1171        // FIXME: When we support where clauses, we should move this
1172        // definition into the definition of clocked_memory
1173        let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
1174        let arg1_loc = args[1].value.loc();
1175        let tup = TypeVar::tuple(
1176            args[1].value.loc(),
1177            vec![
1178                self.new_generic_type(arg1_loc),
1179                addr_type,
1180                self.new_generic_type(arg1_loc),
1181            ],
1182        );
1183        let port_type = TypeVar::array(
1184            arg1_loc,
1185            self.add_type_var(tup),
1186            self.new_generic_tluint(arg1_loc),
1187        )
1188        .insert(self);
1189
1190        self.add_constraint(
1191            addr_size.clone(),
1192            bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
1193            args[1].value.loc(),
1194            &port_type,
1195            ConstraintSource::MemoryIndexing,
1196        );
1197
1198        // NOTE: Unwrap is safe, size is still generic at this point
1199        self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1200        self.unify_expression_generic_error(args[1].value, &port_type, ctx)?;
1201
1202        Ok(())
1203    }
1204
1205    pub fn handle_read_memory(
1206        &mut self,
1207        num_elements: TypeVarID,
1208        addr_size_arg: TypeVarID,
1209        args: &[Argument<Expression, TypeSpec>],
1210        ctx: &Context,
1211    ) -> Result<()> {
1212        let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
1213
1214        self.add_constraint(
1215            addr_size.clone(),
1216            bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
1217            args[1].value.loc(),
1218            &addr_type,
1219            ConstraintSource::MemoryIndexing,
1220        );
1221
1222        // NOTE: Unwrap is safe, size is still generic at this point
1223        self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1224
1225        Ok(())
1226    }
1227
1228    #[tracing::instrument(level = "trace", skip(self, turbofish, where_clauses))]
1229    pub fn create_generic_list(
1230        &mut self,
1231        source: GenericListSource,
1232        type_params: &[Loc<TypeParam>],
1233        scope_type_params: &[Loc<TypeParam>],
1234        turbofish: Option<TurbofishCtx>,
1235        where_clauses: &[Loc<WhereClause>],
1236    ) -> Result<GenericListToken> {
1237        let turbofish_params = if let Some(turbofish) = turbofish.as_ref() {
1238            if type_params.is_empty() {
1239                return Err(Diagnostic::error(
1240                    turbofish.turbofish,
1241                    "Turbofish on non-generic function",
1242                )
1243                .primary_label("Turbofish on non-generic function"));
1244            }
1245
1246            let matched_params =
1247                param_util::match_args_with_params(turbofish.turbofish, &type_params, false)?;
1248
1249            // We want this to be positional, but the params we get from matching are
1250            // named, transform it. We'll do some unwrapping here, but it is safe
1251            // because we know all params are present
1252            matched_params
1253                .iter()
1254                .map(|matched_param| {
1255                    let i = type_params
1256                        .iter()
1257                        .enumerate()
1258                        .find_map(|(i, param)| match &param.inner {
1259                            TypeParam {
1260                                ident,
1261                                name_id: _,
1262                                trait_bounds: _,
1263                                meta: _,
1264                            } => {
1265                                if ident == matched_param.target {
1266                                    Some(i)
1267                                } else {
1268                                    None
1269                                }
1270                            }
1271                        })
1272                        .unwrap();
1273                    (i, matched_param)
1274                })
1275                .sorted_by_key(|(i, _)| *i)
1276                .map(|(_, mp)| Some(mp.value))
1277                .collect::<Vec<_>>()
1278        } else {
1279            type_params.iter().map(|_| None).collect::<Vec<_>>()
1280        };
1281
1282        let mut inline_trait_bounds: Vec<Loc<WhereClause>> = vec![];
1283
1284        let scope_type_params = scope_type_params
1285            .iter()
1286            .map(|param| {
1287                let hir::TypeParam {
1288                    ident,
1289                    name_id,
1290                    trait_bounds,
1291                    meta,
1292                } = &param.inner;
1293                if !trait_bounds.is_empty() {
1294                    if let MetaType::Type = meta {
1295                        inline_trait_bounds.push(
1296                            WhereClause::Type {
1297                                target: name_id.clone().at_loc(ident),
1298                                traits: trait_bounds.clone(),
1299                            }
1300                            .at_loc(param),
1301                        );
1302                    } else {
1303                        return Err(Diagnostic::bug(param, "Trait bounds on generic int")
1304                            .primary_label("Trait bounds are only allowed on type parameters"));
1305                    }
1306                }
1307                Ok((
1308                    name_id.clone(),
1309                    self.new_generic_with_meta(param.loc(), meta.clone()),
1310                ))
1311            })
1312            .collect::<Result<Vec<_>>>()?;
1313
1314        let new_list = type_params
1315            .iter()
1316            .enumerate()
1317            .map(|(i, param)| {
1318                let hir::TypeParam {
1319                    ident,
1320                    name_id,
1321                    trait_bounds,
1322                    meta,
1323                } = &param.inner;
1324
1325                let t = self.new_generic_with_meta(param.loc(), meta.clone());
1326
1327                if let Some(tf) = &turbofish_params[i] {
1328                    let tf_ctx = turbofish.as_ref().unwrap();
1329                    let ty = self.hir_type_expr_to_var(tf, tf_ctx.prev_generic_list)?;
1330                    self.unify(&ty, &t, tf_ctx.type_ctx)
1331                        .into_default_diagnostic(param, self)?;
1332                }
1333
1334                if !trait_bounds.is_empty() {
1335                    if let MetaType::Type = meta {
1336                        inline_trait_bounds.push(
1337                            WhereClause::Type {
1338                                target: name_id.clone().at_loc(ident),
1339                                traits: trait_bounds.clone(),
1340                            }
1341                            .at_loc(param),
1342                        );
1343                    }
1344                    Ok((name_id.clone(), t))
1345                } else {
1346                    Ok((name_id.clone(), t))
1347                }
1348            })
1349            .collect::<Result<Vec<_>>>()?
1350            .into_iter()
1351            .chain(scope_type_params.into_iter())
1352            .map(|(name, t)| (name, t.clone()))
1353            .collect::<HashMap<_, _>>();
1354
1355        self.trace_stack.push(TraceStackEntry::NewGenericList(
1356            new_list
1357                .iter()
1358                .map(|(name, var)| (name.clone(), var.debug_resolve(self)))
1359                .collect(),
1360        ));
1361
1362        let token = self.add_mapped_generic_list(source, new_list.clone());
1363
1364        for constraint in where_clauses.iter().chain(inline_trait_bounds.iter()) {
1365            match &constraint.inner {
1366                WhereClause::Type { target, traits } => {
1367                    self.visit_trait_bounds(target, traits.as_slice(), &token)?;
1368                }
1369                WhereClause::Int { target, constraint } => {
1370                    let int_constraint = self.visit_const_generic(constraint, &token)?;
1371                    let tvar = new_list.get(target).ok_or_else(|| {
1372                        Diagnostic::error(
1373                            target,
1374                            format!("{target} is not a generic parameter on this unit"),
1375                        )
1376                        .primary_label("Not a generic parameter")
1377                    })?;
1378
1379                    self.add_constraint(
1380                        tvar.clone(),
1381                        int_constraint,
1382                        constraint.loc(),
1383                        &tvar,
1384                        ConstraintSource::Where,
1385                    );
1386                }
1387            }
1388        }
1389
1390        Ok(token)
1391    }
1392
1393    /// Adds a generic list with parameters already mapped to types
1394    pub fn add_mapped_generic_list(
1395        &mut self,
1396        source: GenericListSource,
1397        mapping: HashMap<NameID, TypeVarID>,
1398    ) -> GenericListToken {
1399        let reference = match source {
1400            GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1401            GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1402            GenericListSource::ImplBlock { target, id } => {
1403                GenericListToken::ImplBlock(target.clone(), id)
1404            }
1405            GenericListSource::Expression(id) => GenericListToken::Expression(id),
1406        };
1407
1408        if self
1409            .generic_lists
1410            .insert(reference.clone(), mapping)
1411            .is_some()
1412        {
1413            panic!("A generic list already existed for {reference:?}");
1414        }
1415        reference
1416    }
1417
1418    pub fn remove_generic_list(&mut self, source: GenericListSource) {
1419        let reference = match source {
1420            GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1421            GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1422            GenericListSource::ImplBlock { target, id } => {
1423                GenericListToken::ImplBlock(target.clone(), id)
1424            }
1425            GenericListSource::Expression(id) => GenericListToken::Expression(id),
1426        };
1427
1428        self.generic_lists.remove(&reference.clone());
1429    }
1430
1431    #[tracing::instrument(level = "trace", skip_all)]
1432    #[trace_typechecker]
1433    pub fn visit_block(
1434        &mut self,
1435        block: &Block,
1436        ctx: &Context,
1437        generic_list: &GenericListToken,
1438    ) -> Result<()> {
1439        for statement in &block.statements {
1440            self.visit_statement(statement, ctx, generic_list);
1441        }
1442        if let Some(result) = &block.result {
1443            self.visit_expression(result, ctx, generic_list);
1444        }
1445        Ok(())
1446    }
1447
1448    #[tracing::instrument(level = "trace", skip_all)]
1449    pub fn visit_impl_blocks(&mut self, item_list: &ItemList) -> TraitImplList {
1450        let mut trait_impls = TraitImplList::new();
1451        for (target, impls) in &item_list.impls {
1452            for ((trait_name, type_expressions), impl_block) in impls {
1453                let result = (|| {
1454                    let generic_list = self.create_generic_list(
1455                        GenericListSource::ImplBlock {
1456                            target,
1457                            id: impl_block.id,
1458                        },
1459                        &[],
1460                        impl_block.type_params.as_slice(),
1461                        None,
1462                        &[],
1463                    )?;
1464
1465                    let loc = trait_name
1466                        .name_loc()
1467                        .map(|n| ().at_loc(&n))
1468                        .unwrap_or(().at_loc(&impl_block));
1469
1470                    let trait_type_params = type_expressions
1471                        .iter()
1472                        .map(|param| {
1473                            Ok(TemplateTypeVarID::new(self.hir_type_expr_to_var(
1474                                &param.clone().at_loc(&loc),
1475                                &generic_list,
1476                            )?))
1477                        })
1478                        .collect::<Result<_>>()?;
1479
1480                    let target_type_params = impl_block
1481                        .target
1482                        .type_params()
1483                        .into_iter()
1484                        .map(|param| {
1485                            Ok(TemplateTypeVarID::new(self.hir_type_expr_to_var(
1486                                &param.clone().at_loc(&loc),
1487                                &generic_list,
1488                            )?))
1489                        })
1490                        .collect::<Result<_>>()?;
1491
1492                    trait_impls
1493                        .inner
1494                        .entry(target.clone())
1495                        .or_default()
1496                        .push(TraitImpl {
1497                            name: trait_name.clone(),
1498                            target_type_params,
1499                            trait_type_params,
1500
1501                            impl_block: impl_block.inner.clone(),
1502                        });
1503
1504                    Ok(())
1505                })();
1506
1507                match result {
1508                    Ok(()) => {}
1509                    Err(e) => self.diags.errors.push(e),
1510                }
1511            }
1512        }
1513
1514        trait_impls
1515    }
1516
1517    #[trace_typechecker]
1518    pub fn visit_pattern(
1519        &mut self,
1520        pattern: &Loc<Pattern>,
1521        ctx: &Context,
1522        generic_list: &GenericListToken,
1523    ) -> Result<()> {
1524        let new_type = self.new_generic_type(pattern.loc());
1525        self.add_equation(TypedExpression::Id(pattern.inner.id), new_type);
1526        match &pattern.inner.kind {
1527            hir::PatternKind::Integer(val) => {
1528                let (num_t, _) = &self.new_generic_number(pattern.loc(), ctx);
1529                self.add_requirement(Requirement::FitsIntLiteral {
1530                    value: ConstantInt::Literal(val.clone()),
1531                    target_type: num_t.clone().at_loc(pattern),
1532                });
1533                self.unify(pattern, num_t, ctx)
1534                    .expect("Failed to unify new_generic with int");
1535            }
1536            hir::PatternKind::Bool(_) => {
1537                pattern
1538                    .unify_with(&self.t_bool(pattern.loc(), ctx.symtab), self)
1539                    .commit(self, ctx)
1540                    .expect("Expected new_generic with boolean");
1541            }
1542            hir::PatternKind::Name { name, pre_declared } => {
1543                if !pre_declared {
1544                    self.add_equation(
1545                        TypedExpression::Name(name.clone().inner),
1546                        pattern.get_type(self),
1547                    );
1548                }
1549                self.unify(
1550                    &TypedExpression::Id(pattern.id),
1551                    &TypedExpression::Name(name.clone().inner),
1552                    ctx,
1553                )
1554                .into_default_diagnostic(name.loc(), self)?;
1555            }
1556            hir::PatternKind::Tuple(subpatterns) => {
1557                for pattern in subpatterns {
1558                    self.visit_pattern(pattern, ctx, generic_list)?;
1559                }
1560                let tuple_type = self.add_type_var(TypeVar::tuple(
1561                    pattern.loc(),
1562                    subpatterns
1563                        .iter()
1564                        .map(|pattern| {
1565                            let p_type = pattern.get_type(self);
1566                            Ok(p_type)
1567                        })
1568                        .collect::<Result<_>>()?,
1569                ));
1570
1571                self.unify(pattern, &tuple_type, ctx)
1572                    .expect("Unification of new_generic with tuple type cannot fail");
1573            }
1574            hir::PatternKind::Array(inner) => {
1575                for pattern in inner {
1576                    self.visit_pattern(pattern, ctx, generic_list)?;
1577                }
1578                if inner.len() == 0 {
1579                    return Err(
1580                        Diagnostic::error(pattern, "Empty array patterns are unsupported")
1581                            .primary_label("Empty array pattern"),
1582                    );
1583                } else {
1584                    let inner_t = inner[0].get_type(self);
1585
1586                    for pattern in inner.iter().skip(1) {
1587                        self.unify(pattern, &inner_t, ctx)
1588                            .into_default_diagnostic(pattern, self)?;
1589                    }
1590
1591                    pattern
1592                        .unify_with(
1593                            &TypeVar::Known(
1594                                pattern.loc(),
1595                                KnownType::Array,
1596                                vec![
1597                                    inner_t,
1598                                    self.add_type_var(TypeVar::Known(
1599                                        pattern.loc(),
1600                                        KnownType::Integer(inner.len().to_bigint()),
1601                                        vec![],
1602                                    )),
1603                                ],
1604                            )
1605                            .insert(self),
1606                            self,
1607                        )
1608                        .commit(self, ctx)
1609                        .into_default_diagnostic(pattern, self)?;
1610                }
1611            }
1612            hir::PatternKind::Type(name, args) => {
1613                let (condition_type, params, generic_list) =
1614                    match ctx.symtab.patternable_type_by_id(name).inner {
1615                        Patternable {
1616                            kind: PatternableKind::Enum,
1617                            params: _,
1618                        } => {
1619                            let enum_variant = ctx.symtab.enum_variant_by_id(name).inner;
1620                            let generic_list = self.create_generic_list(
1621                                GenericListSource::Anonymous,
1622                                &enum_variant.type_params,
1623                                &[],
1624                                None,
1625                                &[],
1626                            )?;
1627
1628                            let condition_type = self.type_var_from_hir(
1629                                pattern.loc(),
1630                                &enum_variant.output_type,
1631                                &generic_list,
1632                            )?;
1633
1634                            (condition_type, enum_variant.params, generic_list)
1635                        }
1636                        Patternable {
1637                            kind: PatternableKind::Struct,
1638                            params: _,
1639                        } => {
1640                            let s = ctx.symtab.struct_by_id(name).inner;
1641                            let generic_list = self.create_generic_list(
1642                                GenericListSource::Anonymous,
1643                                &s.type_params,
1644                                &[],
1645                                None,
1646                                &[],
1647                            )?;
1648
1649                            let condition_type =
1650                                self.type_var_from_hir(pattern.loc(), &s.self_type, &generic_list)?;
1651
1652                            (condition_type, s.params, generic_list)
1653                        }
1654                    };
1655
1656                self.unify(pattern, &condition_type, ctx)
1657                    .expect("Unification of new_generic with enum cannot fail");
1658
1659                for (
1660                    PatternArgument {
1661                        target,
1662                        value: pattern,
1663                        kind,
1664                    },
1665                    Parameter {
1666                        name: _,
1667                        ty: target_type,
1668                        no_mangle: _,
1669                        field_translator: _,
1670                    },
1671                ) in args.iter().zip(params.0.iter())
1672                {
1673                    self.visit_pattern(pattern, ctx, &generic_list)?;
1674                    let target_type =
1675                        self.type_var_from_hir(target_type.loc(), target_type, &generic_list)?;
1676
1677                    let loc = match kind {
1678                        hir::ArgumentKind::Positional => pattern.loc(),
1679                        hir::ArgumentKind::Named => pattern.loc(),
1680                        hir::ArgumentKind::ShortNamed => target.loc(),
1681                    };
1682
1683                    self.unify(pattern, &target_type, ctx).into_diagnostic(
1684                        loc,
1685                        |d, tm| {
1686                            let (expected, got) = tm.display_e_g(self);
1687                            d.message(format!(
1688                                "Argument type mismatch. Expected {expected} got {got}"
1689                            ))
1690                            .primary_label(format!("expected {expected}"))
1691                        },
1692                        self,
1693                    )?;
1694                }
1695            }
1696        }
1697        Ok(())
1698    }
1699
1700    #[trace_typechecker]
1701    pub fn visit_wal_trace(
1702        &mut self,
1703        trace: &Loc<WalTrace>,
1704        ctx: &Context,
1705        generic_list: &GenericListToken,
1706    ) -> Result<()> {
1707        let WalTrace { clk, rst } = &trace.inner;
1708        clk.as_ref()
1709            .map(|x| {
1710                self.visit_expression(x, ctx, generic_list);
1711                x.unify_with(&self.t_clock(trace.loc(), ctx.symtab), self)
1712                    .commit(self, ctx)
1713                    .into_default_diagnostic(x, self)
1714            })
1715            .transpose()?;
1716        rst.as_ref()
1717            .map(|x| {
1718                self.visit_expression(x, ctx, generic_list);
1719                x.unify_with(&self.t_bool(trace.loc(), ctx.symtab), self)
1720                    .commit(self, ctx)
1721                    .into_default_diagnostic(x, self)
1722            })
1723            .transpose()?;
1724        Ok(())
1725    }
1726
1727    #[trace_typechecker]
1728    pub fn visit_statement_error(
1729        &mut self,
1730        stmt: &Loc<Statement>,
1731        ctx: &Context,
1732        generic_list: &GenericListToken,
1733    ) -> Result<()> {
1734        match &stmt.inner {
1735            Statement::Error => {
1736                if let Some(current_stage_depth) =
1737                    self.pipeline_state.as_ref().map(|s| s.current_stage_depth)
1738                {
1739                    current_stage_depth
1740                        .unify_with(&self.t_err(stmt.loc()), self)
1741                        .commit(self, ctx)
1742                        .unwrap();
1743                }
1744                Ok(())
1745            }
1746            Statement::Binding(Binding {
1747                pattern,
1748                ty,
1749                value,
1750                wal_trace,
1751            }) => {
1752                trace!("Visiting `let {} = ..`", pattern.kind);
1753                self.visit_expression(value, ctx, generic_list);
1754
1755                self.visit_pattern(pattern, ctx, generic_list)
1756                    .handle_in(&mut self.diags);
1757
1758                self.unify(&TypedExpression::Id(pattern.id), value, ctx)
1759                    .into_diagnostic(
1760                        pattern.loc(),
1761                        error_pattern_type_mismatch(
1762                            ty.as_ref().map(|t| t.loc()).unwrap_or_else(|| value.loc()),
1763                            self,
1764                        ),
1765                        self,
1766                    )
1767                    .handle_in(&mut self.diags);
1768
1769                if let Some(t) = ty {
1770                    let tvar = self.type_var_from_hir(t.loc(), t, generic_list)?;
1771                    self.unify(&TypedExpression::Id(pattern.id), &tvar, ctx)
1772                        .into_default_diagnostic(value.loc(), self)
1773                        .handle_in(&mut self.diags);
1774                }
1775
1776                wal_trace
1777                    .as_ref()
1778                    .map(|wt| self.visit_wal_trace(wt, ctx, generic_list))
1779                    .transpose()
1780                    .handle_in(&mut self.diags);
1781
1782                Ok(())
1783            }
1784            Statement::Expression(expr) => {
1785                self.visit_expression(expr, ctx, generic_list);
1786                Ok(())
1787            }
1788            Statement::Register(reg) => self.visit_register(reg, ctx, generic_list),
1789            Statement::Declaration(names) => {
1790                for name in names {
1791                    let new_type = self.new_generic_type(name.loc());
1792                    self.add_equation(TypedExpression::Name(name.clone().inner), new_type);
1793                }
1794                Ok(())
1795            }
1796            Statement::PipelineRegMarker(extra) => {
1797                match extra {
1798                    Some(PipelineRegMarkerExtra::Condition(cond)) => {
1799                        self.visit_expression(cond, ctx, generic_list);
1800                        cond.unify_with(&self.t_bool(cond.loc(), ctx.symtab), self)
1801                            .commit(self, ctx)
1802                            .into_default_diagnostic(cond, self)?;
1803                    }
1804                    Some(PipelineRegMarkerExtra::Count {
1805                        count: _,
1806                        count_typeexpr_id: _,
1807                    }) => {}
1808                    None => {}
1809                }
1810
1811                let current_stage_depth = self
1812                    .pipeline_state
1813                    .clone()
1814                    .ok_or_else(|| {
1815                        diag_anyhow!(stmt, "Found a pipeline reg marker in a non-pipeline")
1816                    })?
1817                    .current_stage_depth;
1818
1819                let new_depth = self.new_generic_tlint(stmt.loc());
1820                let offset = match extra {
1821                    Some(PipelineRegMarkerExtra::Count {
1822                        count,
1823                        count_typeexpr_id,
1824                    }) => {
1825                        let var = self.hir_type_expr_to_var(count, generic_list)?;
1826                        self.add_equation(TypedExpression::Id(*count_typeexpr_id), var.clone());
1827                        var
1828                    }
1829                    Some(PipelineRegMarkerExtra::Condition(_)) | None => self.add_type_var(
1830                        TypeVar::Known(stmt.loc(), KnownType::Integer(1.to_bigint()), vec![]),
1831                    ),
1832                };
1833
1834                let total_depth = ConstraintExpr::Sum(
1835                    Box::new(ConstraintExpr::Var(offset)),
1836                    Box::new(ConstraintExpr::Var(current_stage_depth)),
1837                );
1838                self.pipeline_state
1839                    .as_mut()
1840                    .expect("Expected to have a pipeline state")
1841                    .current_stage_depth = new_depth.clone();
1842
1843                self.add_constraint(
1844                    new_depth.clone(),
1845                    total_depth,
1846                    stmt.loc(),
1847                    &new_depth,
1848                    ConstraintSource::PipelineRegCount {
1849                        reg: stmt.loc(),
1850                        total: self.get_pipeline_state(stmt)?.total_depth.loc(),
1851                    },
1852                );
1853
1854                Ok(())
1855            }
1856            Statement::Label(name) => {
1857                let key = TypedExpression::Name(name.inner.clone());
1858                let var = if !self.equations.contains_key(&key) {
1859                    let var = self.new_generic_tlint(name.loc());
1860                    self.trace_stack.push(TraceStackEntry::AddingPipelineLabel(
1861                        name.inner.clone(),
1862                        var.debug_resolve(self),
1863                    ));
1864                    self.add_equation(key.clone(), var.clone());
1865                    var
1866                } else {
1867                    let var = self.equations.get(&key).unwrap().clone();
1868                    self.trace_stack
1869                        .push(TraceStackEntry::RecoveringPipelineLabel(
1870                            name.inner.clone(),
1871                            var.debug_resolve(self),
1872                        ));
1873                    var
1874                };
1875                // Safe unwrap, unifying with a fresh var
1876                self.unify(
1877                    &var,
1878                    &self.get_pipeline_state(name)?.current_stage_depth.clone(),
1879                    ctx,
1880                )
1881                .unwrap();
1882                Ok(())
1883            }
1884            Statement::WalSuffixed { .. } => Ok(()),
1885            Statement::Assert(expr) => {
1886                self.visit_expression(expr, ctx, generic_list);
1887
1888                expr.unify_with(&self.t_bool(stmt.loc(), ctx.symtab), self)
1889                    .commit(self, ctx)
1890                    .into_default_diagnostic(expr, self)
1891                    .handle_in(&mut self.diags);
1892                Ok(())
1893            }
1894            Statement::Set { target, value } => {
1895                self.visit_expression(target, ctx, generic_list);
1896                self.visit_expression(value, ctx, generic_list);
1897
1898                let inner_type = self.new_generic_type(value.loc());
1899                let outer_type =
1900                    TypeVar::backward(stmt.loc(), inner_type.clone(), self).insert(self);
1901                self.unify_expression_generic_error(target, &outer_type, ctx)
1902                    .handle_in(&mut self.diags);
1903                self.unify_expression_generic_error(value, &inner_type, ctx)
1904                    .handle_in(&mut self.diags);
1905
1906                Ok(())
1907            }
1908        }
1909    }
1910
1911    pub fn visit_statement(
1912        &mut self,
1913        stmt: &Loc<Statement>,
1914        ctx: &Context,
1915        generic_list: &GenericListToken,
1916    ) {
1917        if let Err(e) = self.visit_statement_error(stmt, ctx, generic_list) {
1918            self.diags.errors.push(e);
1919        }
1920    }
1921
1922    #[trace_typechecker]
1923    pub fn visit_register(
1924        &mut self,
1925        reg: &Register,
1926        ctx: &Context,
1927        generic_list: &GenericListToken,
1928    ) -> Result<()> {
1929        self.visit_pattern(&reg.pattern, ctx, generic_list)?;
1930
1931        let type_spec_type = match &reg.value_type {
1932            Some(t) => Some(self.type_var_from_hir(t.loc(), t, generic_list)?.at_loc(t)),
1933            None => None,
1934        };
1935
1936        // We need to do this before visiting value, in case it constrains the
1937        // type of the identifiers in the pattern
1938        if let Some(tvar) = &type_spec_type {
1939            self.unify(&TypedExpression::Id(reg.pattern.id), tvar, ctx)
1940                .into_diagnostic_no_expected_source(
1941                    reg.pattern.loc(),
1942                    error_pattern_type_mismatch(tvar.loc(), self),
1943                    self,
1944                )?;
1945        }
1946
1947        self.visit_expression(&reg.clock, ctx, generic_list);
1948        self.visit_expression(&reg.value, ctx, generic_list);
1949
1950        if let Some(tvar) = &type_spec_type {
1951            self.unify(&reg.value, tvar, ctx)
1952                .into_default_diagnostic(reg.value.loc(), self)?;
1953        }
1954
1955        if let Some((rst_cond, rst_value)) = &reg.reset {
1956            self.visit_expression(rst_cond, ctx, generic_list);
1957            self.visit_expression(rst_value, ctx, generic_list);
1958            // Ensure cond is a boolean
1959            rst_cond
1960                .unify_with(&self.t_bool(rst_cond.loc(), ctx.symtab), self)
1961                .commit(self, ctx)
1962                .into_diagnostic(
1963                    rst_cond.loc(),
1964                    |diag,
1965                     Tm {
1966                         g: got,
1967                         e: _expected,
1968                     }| {
1969                        diag.message(format!(
1970                            "Register reset condition must be a bool, got {got}",
1971                            got = got.display(self)
1972                        ))
1973                        .primary_label("expected bool")
1974                    },
1975                    self,
1976                )?;
1977
1978            // Ensure the reset value has the same type as the register itself
1979            self.unify(&rst_value.inner, &reg.value.inner, ctx)
1980                .into_diagnostic(
1981                    rst_value.loc(),
1982                    |diag, tm| {
1983                        let (expected, got) = tm.display_e_g(self);
1984                        diag.message(format!(
1985                            "Register reset value mismatch. Expected {expected} got {got}"
1986                        ))
1987                        .primary_label(format!("expected {expected}"))
1988                        .secondary_label(&reg.pattern, format!("because this has type {expected}"))
1989                    },
1990                    self,
1991                )?;
1992        }
1993
1994        if let Some(initial) = &reg.initial {
1995            self.visit_expression(initial, ctx, generic_list);
1996
1997            self.unify(&initial.inner, &reg.value.inner, ctx)
1998                .into_diagnostic(
1999                    initial.loc(),
2000                    |diag, tm| {
2001                        let (expected, got) = tm.display_e_g(self);
2002                        diag.message(format!(
2003                            "Register initial value mismatch. Expected {expected} got {got}"
2004                        ))
2005                        .primary_label(format!("expected {expected}, got {got}"))
2006                        .secondary_label(&reg.pattern, format!("because this has type {got}"))
2007                    },
2008                    self,
2009                )?;
2010        }
2011
2012        reg.clock
2013            .unify_with(&self.t_clock(reg.clock.loc(), ctx.symtab), self)
2014            .commit(self, ctx)
2015            .into_diagnostic(
2016                reg.clock.loc(),
2017                |diag,
2018                 Tm {
2019                     g: got,
2020                     e: _expected,
2021                 }| {
2022                    diag.message(format!(
2023                        "Expected clock, got {got}",
2024                        got = got.display(self)
2025                    ))
2026                    .primary_label("expected clock")
2027                },
2028                self,
2029            )?;
2030
2031        self.unify(&TypedExpression::Id(reg.pattern.id), &reg.value, ctx)
2032            .into_diagnostic(
2033                reg.pattern.loc(),
2034                error_pattern_type_mismatch(reg.value.loc(), self),
2035                self,
2036            )?;
2037
2038        Ok(())
2039    }
2040
2041    #[trace_typechecker]
2042    pub fn visit_trait_spec(
2043        &mut self,
2044        trait_spec: &Loc<TraitSpec>,
2045        generic_list: &GenericListToken,
2046    ) -> Result<Loc<TraitReq>> {
2047        let type_params = if let Some(type_params) = &trait_spec.inner.type_params {
2048            type_params
2049                .inner
2050                .iter()
2051                .map(|te| self.hir_type_expr_to_var(te, generic_list))
2052                .collect::<Result<_>>()?
2053        } else {
2054            vec![]
2055        };
2056
2057        Ok(TraitReq {
2058            name: trait_spec.name.clone(),
2059            type_params,
2060        }
2061        .at_loc(trait_spec))
2062    }
2063
2064    #[trace_typechecker]
2065    pub fn visit_trait_bounds(
2066        &mut self,
2067        target: &Loc<NameID>,
2068        traits: &[Loc<TraitSpec>],
2069        generic_list_tok: &GenericListToken,
2070    ) -> Result<()> {
2071        let trait_reqs = traits
2072            .iter()
2073            .map(|spec| self.visit_trait_spec(spec, generic_list_tok))
2074            .collect::<Result<BTreeSet<_>>>()?
2075            .into_iter()
2076            .collect_vec();
2077
2078        if !trait_reqs.is_empty() {
2079            let trait_list = TraitList::from_vec(trait_reqs);
2080
2081            let generic_list = self.generic_lists.get(generic_list_tok).unwrap();
2082
2083            let Some(tvar) = generic_list.get(&target.inner) else {
2084                return Err(Diagnostic::bug(
2085                    target,
2086                    "Couldn't find generic from where clause in generic list",
2087                )
2088                .primary_label(format!(
2089                    "Generic type {} not found in generic list",
2090                    target.inner
2091                )));
2092            };
2093
2094            self.trace_stack.push(TraceStackEntry::AddingTraitBounds(
2095                tvar.debug_resolve(self),
2096                trait_list.clone(),
2097            ));
2098
2099            let TypeVar::Unknown(loc, id, old_trait_list, MetaType::Type) = tvar.resolve(self)
2100            else {
2101                return Err(Diagnostic::bug(
2102                    target,
2103                    "Trait bounds on known type or type-level integer",
2104                )
2105                .primary_label(format!(
2106                    "Trait bounds on {}, which should've been caught in ast-lowering",
2107                    target.inner
2108                )));
2109            };
2110
2111            let new_tvar = self.add_type_var(TypeVar::Unknown(
2112                *loc,
2113                *id,
2114                old_trait_list.clone().extend(trait_list),
2115                MetaType::Type,
2116            ));
2117
2118            trace!(
2119                "Adding trait bound {} on type {}",
2120                new_tvar.display_with_meta(true, self),
2121                target.inner
2122            );
2123
2124            let generic_list = self.generic_lists.get_mut(generic_list_tok).unwrap();
2125            generic_list.insert(target.inner.clone(), new_tvar);
2126        }
2127
2128        Ok(())
2129    }
2130
2131    pub fn visit_const_generic_with_id(
2132        &mut self,
2133        gen: &Loc<ConstGenericWithId>,
2134        generic_list_token: &GenericListToken,
2135        constraint_source: ConstraintSource,
2136        ctx: &Context,
2137    ) -> Result<TypeVarID> {
2138        let var = match &gen.inner.inner {
2139            ConstGeneric::Name(name) => {
2140                let ty = &ctx.symtab.type_symbol_by_id(&name);
2141                match &ty.inner {
2142                    TypeSymbol::Declared(_, _) => {
2143                        return Err(Diagnostic::error(
2144                            name,
2145                            "{type_decl_kind} cannot be used in a const generic expression",
2146                        )
2147                        .primary_label("Type in  const generic")
2148                        .secondary_label(ty, "{name} is declared here"))
2149                    }
2150                    TypeSymbol::GenericArg { .. } | TypeSymbol::GenericMeta(MetaType::Type) => {
2151                        return Err(Diagnostic::error(
2152                            name,
2153                            "Generic types cannot be used in const generic expressions",
2154                        )
2155                        .primary_label("Type in  const generic")
2156                        .secondary_label(ty, "{name} is declared here")
2157                        .span_suggest_insert_before(
2158                            "Consider making this a value",
2159                            ty.loc(),
2160                            "#uint ",
2161                        ))
2162                    }
2163                    TypeSymbol::GenericMeta(MetaType::Number) => {
2164                        self.new_generic_tlnumber(gen.loc())
2165                    }
2166                    TypeSymbol::GenericMeta(MetaType::Int) => self.new_generic_tlint(gen.loc()),
2167                    TypeSymbol::GenericMeta(MetaType::Uint) => self.new_generic_tluint(gen.loc()),
2168                    TypeSymbol::GenericMeta(MetaType::Bool) => self.new_generic_tlbool(gen.loc()),
2169                    TypeSymbol::GenericMeta(MetaType::Any) => {
2170                        diag_bail!(gen, "Found any meta type")
2171                    }
2172                    TypeSymbol::Alias(_) => {
2173                        return Err(Diagnostic::error(
2174                            gen,
2175                            "Aliases are not currently supported in const generics",
2176                        )
2177                        .secondary_label(ty, "Alias defined here"))
2178                    }
2179                }
2180            }
2181            ConstGeneric::Const(_)
2182            | ConstGeneric::Add(_, _)
2183            | ConstGeneric::Sub(_, _)
2184            | ConstGeneric::Mul(_, _)
2185            | ConstGeneric::Div(_, _)
2186            | ConstGeneric::Mod(_, _)
2187            | ConstGeneric::UintBitsToFit(_) => self.new_generic_tlnumber(gen.loc()),
2188            ConstGeneric::Eq(_, _) | ConstGeneric::NotEq(_, _) => {
2189                self.new_generic_tlbool(gen.loc())
2190            }
2191        };
2192        let constraint = self.visit_const_generic(&gen.inner.inner, generic_list_token)?;
2193        self.add_equation(TypedExpression::Id(gen.id), var.clone());
2194        self.add_constraint(var.clone(), constraint, gen.loc(), &var, constraint_source);
2195        Ok(var)
2196    }
2197
2198    #[trace_typechecker]
2199    pub fn visit_const_generic(
2200        &self,
2201        constraint: &ConstGeneric,
2202        generic_list: &GenericListToken,
2203    ) -> Result<ConstraintExpr> {
2204        let wrap = |lhs,
2205                    rhs,
2206                    wrapper: fn(Box<ConstraintExpr>, Box<ConstraintExpr>) -> ConstraintExpr|
2207         -> Result<_> {
2208            Ok(wrapper(
2209                Box::new(self.visit_const_generic(lhs, generic_list)?),
2210                Box::new(self.visit_const_generic(rhs, generic_list)?),
2211            ))
2212        };
2213        let constraint = match constraint {
2214            ConstGeneric::Name(n) => {
2215                let var = self
2216                    .get_generic_list(generic_list)
2217                    .ok_or_else(|| diag_anyhow!(n, "Found no generic list"))?
2218                    .get(n)
2219                    .ok_or_else(|| {
2220                        Diagnostic::bug(n, "Found non-generic argument in where clause")
2221                    })?;
2222                ConstraintExpr::Var(*var)
2223            }
2224            ConstGeneric::Const(val) => ConstraintExpr::Integer(val.clone()),
2225            ConstGeneric::Add(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Sum)?,
2226            ConstGeneric::Sub(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Difference)?,
2227            ConstGeneric::Mul(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Product)?,
2228            ConstGeneric::Div(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Div)?,
2229            ConstGeneric::Mod(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Mod)?,
2230            ConstGeneric::Eq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Eq)?,
2231            ConstGeneric::NotEq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::NotEq)?,
2232            ConstGeneric::UintBitsToFit(a) => ConstraintExpr::UintBitsToRepresent(Box::new(
2233                self.visit_const_generic(a, generic_list)?,
2234            )),
2235        };
2236        Ok(constraint)
2237    }
2238}
2239
2240// Private helper functions
2241impl TypeState {
2242    fn new_typeid(&self) -> u64 {
2243        let mut next = self.next_typeid.borrow_mut();
2244        let result = *next;
2245        *next += 1;
2246        result
2247    }
2248
2249    pub fn add_equation(&mut self, expression: TypedExpression, var: TypeVarID) {
2250        self.trace_stack.push(TraceStackEntry::AddingEquation(
2251            expression.clone(),
2252            var.debug_resolve(self),
2253        ));
2254        if let Some(prev) = self.equations.insert(expression.clone(), var.clone()) {
2255            let var = var.clone();
2256            let expr = expression.clone();
2257            println!("{}", format_trace_stack(self));
2258            panic!("Adding equation for {} == {} but a previous eq exists.\n\tIt was previously bound to {}", expr, var.debug_resolve(self), prev.debug_resolve(self))
2259        }
2260    }
2261
2262    fn add_constraint(
2263        &mut self,
2264        lhs: TypeVarID,
2265        rhs: ConstraintExpr,
2266        loc: Loc<()>,
2267        inside: &TypeVarID,
2268        source: ConstraintSource,
2269    ) {
2270        let replaces = lhs.clone();
2271        let rhs = rhs.with_context(&replaces, &inside, source).at_loc(&loc);
2272
2273        self.trace_stack.push(TraceStackEntry::AddingConstraint(
2274            lhs.debug_resolve(self),
2275            rhs.inner.clone(),
2276        ));
2277
2278        self.constraints.add_int_constraint(lhs, rhs);
2279    }
2280
2281    fn add_requirement(&mut self, requirement: Requirement) {
2282        self.trace_stack
2283            .push(TraceStackEntry::AddRequirement(requirement.clone()));
2284        self.requirements.push(requirement)
2285    }
2286
2287    /// Performs unification but does not update constraints. This is done to avoid
2288    /// updating constraints more often than necessary. Technically, constraints can
2289    /// be updated even less often, but `unify` is a pretty natural point to do so.
2290
2291    fn unify_inner(
2292        &mut self,
2293        e1: &impl HasType,
2294        e2: &impl HasType,
2295        ctx: &Context,
2296    ) -> std::result::Result<TypeVarID, UnificationError> {
2297        let v1 = e1.get_type(self);
2298        let v2 = e2.get_type(self);
2299
2300        trace!(
2301            "Unifying {} with {}",
2302            v1.debug_resolve(self),
2303            v2.debug_resolve(self)
2304        );
2305
2306        self.trace_stack.push(TraceStackEntry::TryingUnify(
2307            v1.debug_resolve(self),
2308            v2.debug_resolve(self),
2309        ));
2310
2311        macro_rules! err_producer {
2312            () => {{
2313                self.trace_stack
2314                    .push(TraceStackEntry::Message("Produced error".to_string()));
2315                UnificationError::Normal(Tm {
2316                    g: UnificationTrace::new(v1),
2317                    e: UnificationTrace::new(v2),
2318                })
2319            }};
2320        }
2321        macro_rules! meta_err_producer {
2322            () => {{
2323                self.trace_stack
2324                    .push(TraceStackEntry::Message("Produced error".to_string()));
2325                UnificationError::MetaMismatch(Tm {
2326                    g: UnificationTrace::new(v1),
2327                    e: UnificationTrace::new(v2),
2328                })
2329            }};
2330        }
2331
2332        macro_rules! unify_if {
2333            ($condition:expr, $new_type:expr, $replaced_type:expr) => {
2334                if $condition {
2335                    Ok(($new_type, $replaced_type))
2336                } else {
2337                    Err(err_producer!())
2338                }
2339            };
2340        }
2341
2342        let unify_params = |s: &mut Self,
2343                            p1: &[TypeVarID],
2344                            p2: &[TypeVarID]|
2345         -> std::result::Result<(), UnificationError> {
2346            if p1.len() != p2.len() {
2347                return Err({
2348                    s.trace_stack
2349                        .push(TraceStackEntry::Message("Produced error".to_string()));
2350                    UnificationError::Normal(Tm {
2351                        g: UnificationTrace::new(v1),
2352                        e: UnificationTrace::new(v2),
2353                    })
2354                });
2355            }
2356
2357            for (t1, t2) in p1.iter().zip(p2.iter()) {
2358                match s.unify_inner(t1, t2, ctx) {
2359                    Ok(result) => result,
2360                    Err(e) => {
2361                        s.trace_stack
2362                            .push(TraceStackEntry::Message("Adding context".to_string()));
2363                        return Err(e).add_context(v1.clone(), v2.clone());
2364                    }
2365                };
2366            }
2367            Ok(())
2368        };
2369
2370        // Figure out the most general type, and take note if we need to
2371        // do any replacement of the types in the rest of the state
2372        let result = match (
2373            &(v1, v1.resolve(self).clone()),
2374            &(v2, v2.resolve(self).clone()),
2375        ) {
2376            ((_, TypeVar::Known(_, KnownType::Error, _)), _) => Ok((v1, vec![v2])),
2377            (_, (_, TypeVar::Known(_, KnownType::Error, _))) => Ok((v2, vec![v1])),
2378            ((_, TypeVar::Known(_, t1, p1)), (_, TypeVar::Known(_, t2, p2))) => {
2379                match (t1, t2) {
2380                    (KnownType::Integer(val1), KnownType::Integer(val2)) => {
2381                        // NOTE: We get better error messages if we don't actually
2382                        // replace v1 with v2. Not entirely sure why, but since they already
2383                        // have the same known type, we're fine. The same applies
2384                        // to all known types
2385                        unify_if!(val1 == val2, v1, vec![])
2386                    }
2387                    (KnownType::Named(n1), KnownType::Named(n2)) => {
2388                        match (
2389                            &ctx.symtab.type_symbol_by_id(n1).inner,
2390                            &ctx.symtab.type_symbol_by_id(n2).inner,
2391                        ) {
2392                            (TypeSymbol::Declared(_, _), TypeSymbol::Declared(_, _)) => {
2393                                if n1 != n2 {
2394                                    return Err(err_producer!());
2395                                }
2396
2397                                let new_ts1 = ctx.symtab.type_symbol_by_id(n1).inner;
2398                                let new_ts2 = ctx.symtab.type_symbol_by_id(n2).inner;
2399                                unify_params(self, &p1, &p2)?;
2400                                unify_if!(new_ts1 == new_ts2, v1, vec![])
2401                            }
2402                            (TypeSymbol::Declared(_, _), TypeSymbol::GenericArg { traits }) => {
2403                                if !traits.is_empty() {
2404                                    todo!("Implement trait unifictaion");
2405                                }
2406                                Ok((v1, vec![]))
2407                            }
2408                            (TypeSymbol::GenericArg { traits }, TypeSymbol::Declared(_, _)) => {
2409                                if !traits.is_empty() {
2410                                    todo!("Implement trait unifictaion");
2411                                }
2412                                Ok((v2, vec![]))
2413                            }
2414                            (
2415                                TypeSymbol::GenericArg { traits: ltraits },
2416                                TypeSymbol::GenericArg { traits: rtraits },
2417                            ) => {
2418                                if !ltraits.is_empty() || !rtraits.is_empty() {
2419                                    todo!("Implement trait unifictaion");
2420                                }
2421                                Ok((v1, vec![]))
2422                            }
2423                            (TypeSymbol::Declared(_, _), TypeSymbol::GenericMeta(_)) => todo!(),
2424                            (TypeSymbol::GenericArg { traits: _ }, TypeSymbol::GenericMeta(_)) => {
2425                                todo!()
2426                            }
2427                            (TypeSymbol::GenericMeta(_), TypeSymbol::Declared(_, _)) => todo!(),
2428                            (TypeSymbol::GenericMeta(_), TypeSymbol::GenericArg { traits: _ }) => {
2429                                todo!()
2430                            }
2431                            (TypeSymbol::Alias(_), _) | (_, TypeSymbol::Alias(_)) => {
2432                                return Err(UnificationError::Specific(Diagnostic::bug(
2433                                    ().nowhere(),
2434                                    "Encountered a raw type alias during unification",
2435                                )))
2436                            }
2437                            (TypeSymbol::GenericMeta(_), TypeSymbol::GenericMeta(_)) => todo!(),
2438                        }
2439                    }
2440                    (KnownType::Array, KnownType::Array)
2441                    | (KnownType::Tuple, KnownType::Tuple)
2442                    | (KnownType::Wire, KnownType::Wire)
2443                    | (KnownType::Inverted, KnownType::Inverted) => {
2444                        // Note, replacements should only occur when a variable goes from Unknown
2445                        // to Known, not when the base is the same. Replacements take care
2446                        // of parameters. Therefore, None is returned here
2447                        unify_params(self, &p1, &p2)?;
2448                        Ok((v1, vec![]))
2449                    }
2450                    (_, _) => Err(err_producer!()),
2451                }
2452            }
2453            // Unknown with unknown requires merging traits
2454            (
2455                (_, TypeVar::Unknown(loc1, _, traits1, meta1)),
2456                (_, TypeVar::Unknown(loc2, _, traits2, meta2)),
2457            ) => {
2458                let new_loc = if meta1.is_more_concrete_than(meta2) {
2459                    loc1
2460                } else {
2461                    loc2
2462                };
2463                let new_t = match unify_meta(meta1, meta2) {
2464                    Some(meta @ MetaType::Any) | Some(meta @ MetaType::Number) => {
2465                        if traits1.inner.is_empty() || traits2.inner.is_empty() {
2466                            return Err(UnificationError::Specific(diag_anyhow!(
2467                                new_loc,
2468                                "Inferred an any meta-type with traits"
2469                            )));
2470                        }
2471                        self.new_generic_with_meta(*loc1, meta)
2472                    }
2473                    Some(MetaType::Type) => {
2474                        let new_trait_names = traits1
2475                            .inner
2476                            .iter()
2477                            .chain(traits2.inner.iter())
2478                            .map(|t| t.name.clone())
2479                            .collect::<BTreeSet<_>>()
2480                            .into_iter()
2481                            .collect::<Vec<_>>();
2482
2483                        let new_traits = new_trait_names
2484                            .iter()
2485                            .map(
2486                                |name| match (traits1.get_trait(name), traits2.get_trait(name)) {
2487                                    (Some(req1), Some(req2)) => {
2488                                        let new_params = req1
2489                                            .inner
2490                                            .type_params
2491                                            .iter()
2492                                            .zip(req2.inner.type_params.iter())
2493                                            .map(|(p1, p2)| self.unify(p1, p2, ctx))
2494                                            .collect::<std::result::Result<_, UnificationError>>(
2495                                            )?;
2496
2497                                        Ok(TraitReq {
2498                                            name: name.clone(),
2499                                            type_params: new_params,
2500                                        }
2501                                        .at_loc(req1))
2502                                    }
2503                                    (Some(t), None) => Ok(t.clone()),
2504                                    (None, Some(t)) => Ok(t.clone()),
2505                                    (None, None) => panic!("Found a trait but neither side has it"),
2506                                },
2507                            )
2508                            .collect::<std::result::Result<Vec<_>, UnificationError>>()?;
2509
2510                        self.new_generic_with_traits(*new_loc, TraitList::from_vec(new_traits))
2511                    }
2512                    Some(MetaType::Int) => self.new_generic_tlint(*new_loc),
2513                    Some(MetaType::Uint) => self.new_generic_tluint(*new_loc),
2514                    Some(MetaType::Bool) => self.new_generic_tlbool(*new_loc),
2515                    None => return Err(meta_err_producer!()),
2516                };
2517                Ok((new_t, vec![v1, v2]))
2518            }
2519            (
2520                (otherid, TypeVar::Known(loc, base, params)),
2521                (ukid, TypeVar::Unknown(ukloc, _, traits, meta)),
2522            )
2523            | (
2524                (ukid, TypeVar::Unknown(ukloc, _, traits, meta)),
2525                (otherid, TypeVar::Known(loc, base, params)),
2526            ) => {
2527                let trait_is_expected = match (&v1.resolve(self), &v2.resolve(self)) {
2528                    (TypeVar::Known(_, _, _), _) => true,
2529                    _ => false,
2530                };
2531
2532                let impls = self.ensure_impls(otherid, traits, trait_is_expected, ukloc, ctx)?;
2533
2534                self.trace_stack.push(TraceStackEntry::Message(
2535                    "Unifying trait_parameters".to_string(),
2536                ));
2537                let mut new_params = params.clone();
2538                for (trait_impl, trait_req) in impls {
2539                    let mut param_map = BTreeMap::new();
2540
2541                    for (l, r) in trait_req
2542                        .type_params
2543                        .iter()
2544                        .zip(trait_impl.trait_type_params)
2545                    {
2546                        let copy = r.make_copy_with_mapping(self, &mut param_map);
2547                        self.unify(l, &copy, ctx)?;
2548                    }
2549
2550                    new_params = trait_impl
2551                        .target_type_params
2552                        .iter()
2553                        .zip(new_params)
2554                        .map(|(l, r)| {
2555                            let copy = l.make_copy_with_mapping(self, &mut param_map);
2556                            self.unify(&copy, &r, ctx).add_context(*ukid, *otherid)
2557                        })
2558                        .collect::<std::result::Result<_, _>>()?
2559                }
2560
2561                match (base, meta) {
2562                    (KnownType::Error, _) => {
2563                        unreachable!()
2564                    }
2565                    // Any matches all types
2566                    (_, MetaType::Any)
2567                    // All types are types
2568                    | (KnownType::Named(_), MetaType::Type)
2569                    | (KnownType::Tuple, MetaType::Type)
2570                    | (KnownType::Array, MetaType::Type)
2571                    | (KnownType::Wire, MetaType::Type)
2572                    | (KnownType::Bool(_), MetaType::Bool)
2573                    | (KnownType::Inverted, MetaType::Type)
2574                    // Integers match ints and numbers
2575                    | (KnownType::Integer(_), MetaType::Int)
2576                    | (KnownType::Integer(_), MetaType::Number)
2577                    => {
2578                        let new = self.add_type_var(TypeVar::Known(*loc, base.clone(), new_params));
2579
2580                        Ok((new, vec![otherid.clone(), ukid.clone()]))
2581                    },
2582                    // Integers match uints but only if the concrete integer is >= 0
2583                    (KnownType::Integer(val), MetaType::Uint)
2584                    => {
2585                        if val < &0.to_bigint() {
2586                            Err(meta_err_producer!())
2587                        } else {
2588                            let new = self.add_type_var(TypeVar::Known(*loc, base.clone(), new_params));
2589
2590                            Ok((new, vec![otherid.clone(), ukid.clone()]))
2591                        }
2592                    },
2593
2594                    // Integer with type
2595                    (KnownType::Integer(_), MetaType::Type) => Err(meta_err_producer!()),
2596
2597                    // Bools only unify with any or bool
2598                    (_, MetaType::Bool) => Err(meta_err_producer!()),
2599                    (KnownType::Bool(_), _) => Err(meta_err_producer!()),
2600
2601                    // Type with integer
2602                    (KnownType::Named(_), MetaType::Int | MetaType::Number | MetaType::Uint)
2603                    | (KnownType::Tuple, MetaType::Int | MetaType::Uint | MetaType::Number)
2604                    | (KnownType::Array, MetaType::Int | MetaType::Uint | MetaType::Number)
2605                    | (KnownType::Wire, MetaType::Int | MetaType::Uint | MetaType::Number)
2606                    | (KnownType::Inverted, MetaType::Int | MetaType::Uint | MetaType::Number)
2607                    => Err(meta_err_producer!())
2608                }
2609            }
2610        };
2611
2612        let (new_type, replaced_types) = result?;
2613
2614        self.trace_stack.push(TraceStackEntry::Unified(
2615            v1.debug_resolve(self),
2616            v2.debug_resolve(self),
2617            new_type.debug_resolve(self),
2618            replaced_types
2619                .iter()
2620                .map(|v| v.debug_resolve(self))
2621                .collect(),
2622        ));
2623
2624        for replaced_type in &replaced_types {
2625            if v1.inner != v2.inner {
2626                let (from, to) = (replaced_type.get_type(self), new_type.get_type(self));
2627                self.replacements.insert(from, to);
2628                if let Err(rec) = self.check_type_for_recursion(to, &mut vec![]) {
2629                    let err_t = self.t_err(().nowhere());
2630                    self.replacements.insert(to, err_t);
2631                    return Err(UnificationError::RecursiveType(rec));
2632                }
2633            }
2634        }
2635
2636        Ok(new_type)
2637    }
2638
2639    pub fn can_unify(&mut self, e1: &impl HasType, e2: &impl HasType, ctx: &Context) -> bool {
2640        self.trace_stack
2641            .push(TraceStackEntry::Enter("Running can_unify".to_string()));
2642        let result = self.do_and_restore(|s| s.unify(e1, e2, ctx)).is_ok();
2643        self.trace_stack.push(TraceStackEntry::Exit);
2644        result
2645    }
2646
    /// Unifies `e1` with `e2`, then propagates the consequences through the type-level
    /// value constraints.
    ///
    /// After the initial `unify_inner`, the constraints are repeatedly re-evaluated.
    /// Every newly inferred value is turned into a known type and unified with the
    /// variable it constrains. This repeats until a pass yields no new information.
    ///
    /// Returns the type variable produced by the initial `unify_inner` call.
    ///
    /// # Errors
    /// Propagates errors from `unify_inner`. A `Normal` or `MetaMismatch` error raised
    /// while applying an inferred constraint is rewrapped as
    /// `UnificationError::FromConstraints`, carrying the constraint's source, its
    /// location, and whether it was a meta-type mismatch.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn unify(
        &mut self,
        e1: &impl HasType,
        e2: &impl HasType,
        ctx: &Context,
    ) -> std::result::Result<TypeVarID, UnificationError> {
        let new_type = self.unify_inner(e1, e2, ctx)?;

        // With replacement done, some of our constraints may have been updated to provide
        // more type inference information. Try to do unification of those new constraints too
        loop {
            trace!("Updating constraints");
            // NOTE: Cloning here is kind of ugly
            let new_info;
            (self.constraints, new_info) = self
                .constraints
                .clone()
                .update_type_level_value_constraints(self);

            // Fixed point reached: no constraint produced a new value
            if new_info.is_empty() {
                break;
            }

            for constraint in new_info {
                trace!(
                    "Constraint replaces {} with {:?}",
                    constraint.inner.0.display(self),
                    constraint.inner.1
                );

                let ((var, replacement), loc) = constraint.split_loc();

                self.trace_stack
                    .push(TraceStackEntry::InferringFromConstraints(
                        var.debug_resolve(self),
                        replacement.val.clone(),
                    ));

                // The constraint produced a concrete known type for `var`; materialize it
                // as a type variable and unify it with `var`
                let expected_type = self.add_type_var(TypeVar::Known(loc, replacement.val, vec![]));
                let result = self.unify_inner(&expected_type.clone().at_loc(&loc), &var, ctx);
                let is_meta_error = matches!(result, Err(UnificationError::MetaMismatch { .. }));
                match result {
                    Ok(_) => {}
                    // Plain mismatches are re-expressed in terms of the type the constraint
                    // lives inside, so the diagnostic points at the constrained context
                    Err(UnificationError::Normal(Tm { mut e, mut g }))
                    | Err(UnificationError::MetaMismatch(Tm { mut e, mut g })) => {
                        e.inside.replace(
                            replacement
                                .context
                                .inside
                                .replace_inside(var, e.failing, self),
                        );
                        g.inside.replace(
                            replacement
                                .context
                                .inside
                                .replace_inside(var, g.failing, self),
                        );
                        return Err(UnificationError::FromConstraints {
                            got: g,
                            expected: e,
                            source: replacement.context.source,
                            loc,
                            is_meta_error,
                        });
                    }
                    // Already-specific errors are propagated unchanged
                    Err(
                        e @ UnificationError::FromConstraints { .. }
                        | e @ UnificationError::Specific { .. }
                        | e @ UnificationError::RecursiveType(_)
                        | e @ UnificationError::UnsatisfiedTraits { .. },
                    ) => return Err(e),
                };
            }
        }

        Ok(new_type)
    }
2726
    /// Checks whether `ty` (transitively) contains itself.
    ///
    /// The walk descends into the type parameters of known types. For unknown types it
    /// descends into the type parameters of their trait requirements. `seen` holds the
    /// path of type variables from the root to the current node. `ty` is pushed on
    /// entry and popped before returning `Ok(())`; on the error path it is left pushed.
    ///
    /// On recursion, returns `Err` with a textual rendering of the recursive type. In
    /// that rendering, `*` (known types) or `...` (trait parameters) marks the point
    /// where a variable on the path is reached again, and `_` stands for the other
    /// parameters at each level.
    fn check_type_for_recursion(
        &self,
        ty: TypeVarID,
        seen: &mut Vec<TypeVarID>,
    ) -> std::result::Result<(), String> {
        seen.push(ty);
        match ty.resolve(self) {
            TypeVar::Known(_, base, params) => {
                for (i, param) in params.iter().enumerate() {
                    // Direct cycle back to a variable on the current path
                    if seen.contains(param) {
                        return Err("*".to_string());
                    }

                    if let Err(rest) = self.check_type_for_recursion(*param, seen) {
                        // Render this level with the recursive parameter expanded and the
                        // others elided as `_`
                        let list = params
                            .iter()
                            .enumerate()
                            .map(|(j, _)| {
                                if j == i {
                                    rest.clone()
                                } else {
                                    "_".to_string()
                                }
                            })
                            .join(", ");

                        match base {
                            // NOTE(review): recursion under an error type is ignored and the
                            // loop continues, but the failed inner call left its entries in
                            // `seen` — confirm error types never carry parameters
                            KnownType::Error => {}
                            KnownType::Named(name_id) => {
                                return Err(format!("{name_id}<{}>", list));
                            }
                            KnownType::Bool(_) | KnownType::Integer(_) => {
                                unreachable!("Encountered recursive type level bool or int")
                            }
                            KnownType::Tuple => return Err(format!("({})", list)),
                            KnownType::Array => return Err(format!("[{}]", list)),
                            KnownType::Wire => return Err(format!("&{}", list)),
                            KnownType::Inverted => return Err(format!("inv {}", list)),
                        }
                    }
                }
            }
            TypeVar::Unknown(_, _, traits, _) => {
                // Unknown types can still be recursive through their trait requirements
                for t in &traits.inner {
                    for param in &t.type_params {
                        if seen.contains(param) {
                            return Err("...".to_string());
                        }

                        self.check_type_for_recursion(*param, seen)?;
                    }
                }
            }
        }
        seen.pop();
        Ok(())
    }
2784
    /// Checks that the known type `var` implements every trait in `traits`.
    ///
    /// Returns one `(TraitImpl, TraitReq)` pair per requirement, pairing the single
    /// matching impl with the requirement it satisfies. Candidate impls are found by
    /// tentatively unifying copies of their parameters. Replacements made while probing
    /// are discarded with `checkpoint`/`restore`.
    ///
    /// `trait_is_expected` selects whether the trait list is reported as the "expected"
    /// side of an error. `trait_list_loc` locates the generic type variable created to
    /// describe the required traits in errors.
    ///
    /// # Errors
    /// - `UnificationError::Normal` if traits are unsatisfied and either
    ///   `trait_is_expected` is false, or the only unsatisfied trait is `Number`.
    /// - `UnificationError::UnsatisfiedTraits` for other unsatisfied traits when
    ///   `trait_is_expected` is true.
    /// - `UnificationError::Specific` if more than one impl matches a single requirement.
    ///
    /// # Panics
    /// If `var` resolves to an unknown type, or if the `Number` trait is not found in
    /// the symbol table.
    fn ensure_impls(
        &mut self,
        var: &TypeVarID,
        traits: &TraitList,
        trait_is_expected: bool,
        trait_list_loc: &Loc<()>,
        ctx: &Context,
    ) -> std::result::Result<Vec<(TraitImpl, TraitReq)>, UnificationError> {
        self.trace_stack.push(TraceStackEntry::EnsuringImpls(
            var.debug_resolve(self),
            traits.clone(),
            trait_is_expected,
        ));

        let number = ctx
            .symtab
            .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
            .expect("Did not find number in symtab")
            .0;

        // Builds the error for a set of unsatisfied traits. A lone unsatisfied `Number`
        // requirement is reported as a plain type mismatch rather than a trait error.
        macro_rules! error_producer {
            ($required_traits:expr) => {
                if trait_is_expected {
                    if $required_traits.inner.len() == 1
                        && $required_traits
                            .get_trait(&TraitName::Named(number.clone().nowhere()))
                            .is_some()
                    {
                        Err(UnificationError::Normal(Tm {
                            e: UnificationTrace::new(
                                self.new_generic_with_traits(*trait_list_loc, $required_traits),
                            ),
                            g: UnificationTrace::new(var.clone()),
                        }))
                    } else {
                        Err(UnificationError::UnsatisfiedTraits {
                            var: *var,
                            traits: $required_traits.inner,
                            target_loc: trait_list_loc.clone(),
                        })
                    }
                } else {
                    Err(UnificationError::Normal(Tm {
                        e: UnificationTrace::new(var.clone()),
                        g: UnificationTrace::new(
                            self.new_generic_with_traits(*trait_list_loc, $required_traits),
                        ),
                    }))
                }
            };
        }

        match &var.resolve(self).clone() {
            TypeVar::Known(_, known, params) if known.into_impl_target().is_some() => {
                let Some(target) = known.into_impl_target() else {
                    unreachable!()
                };

                // Left: requirement satisfied by exactly one impl. Right: unsatisfied.
                let (impls, unsatisfied): (Vec<_>, Vec<_>) = traits
                    .inner
                    .iter()
                    .map(|trait_req| {
                        if let Some(impld) = self.trait_impls.inner.get(&target).cloned() {
                            // Get a list of implementations of this trait where the type
                            // parameters can match
                            let target_impls = impld
                                .iter()
                                .filter_map(|trait_impl| {
                                    // Probe with fresh copies of the impl's parameters;
                                    // `restore` discards the replacements made here
                                    self.checkpoint();
                                    let trait_params_match = trait_impl
                                        .trait_type_params
                                        .iter()
                                        .zip(trait_req.type_params.iter())
                                        .all(|(l, r)| {
                                            let l = l.make_copy(self);
                                            self.unify(&l, r, ctx).is_ok()
                                        });

                                    let impl_params_match =
                                        trait_impl.target_type_params.iter().zip(params).all(
                                            |(l, r)| {
                                                let l = l.make_copy(self);
                                                self.unify(&l, r, ctx).is_ok()
                                            },
                                        );
                                    self.restore();

                                    if trait_impl.name == trait_req.name
                                        && trait_params_match
                                        && impl_params_match
                                    {
                                        Some(trait_impl)
                                    } else {
                                        None
                                    }
                                })
                                .collect::<Vec<_>>();

                            if target_impls.len() == 0 {
                                Ok(Either::Right(trait_req.clone()))
                            } else if target_impls.len() == 1 {
                                let target_impl = *target_impls.last().unwrap();
                                Ok(Either::Left((target_impl.clone(), trait_req.inner.clone())))
                            } else {
                                Err(UnificationError::Specific(diag_anyhow!(
                                    trait_req,
                                    "Found more than one impl of {} for {}",
                                    trait_req.display(self),
                                    var.display(self)
                                )))
                            }
                        } else {
                            // No impls at all exist for this target
                            Ok(Either::Right(trait_req.clone()))
                        }
                    })
                    .collect::<std::result::Result<Vec<_>, _>>()?
                    .into_iter()
                    .partition_map(|x| x);

                if unsatisfied.is_empty() {
                    self.trace_stack.push(TraceStackEntry::Message(
                        "Ensuring impl successful".to_string(),
                    ));
                    Ok(impls)
                } else {
                    error_producer!(TraitList::from_vec(unsatisfied.clone()))
                }
            }
            TypeVar::Unknown(_, _, _, _) => {
                panic!("running ensure_impls on an unknown type")
            }
            // Known types that cannot be impl targets satisfy only an empty trait list
            _ => {
                if traits.inner.is_empty() {
                    Ok(vec![])
                } else {
                    error_producer!(traits.clone())
                }
            }
        }
    }
2925
2926    pub fn unify_expression_generic_error(
2927        &mut self,
2928        expr: &Loc<Expression>,
2929        other: &impl HasType,
2930        ctx: &Context,
2931    ) -> Result<TypeVarID> {
2932        self.unify(&expr.inner, other, ctx)
2933            .into_default_diagnostic(expr.loc(), self)
2934    }
2935
    /// Re-checks all pending requirements until no more progress is made.
    ///
    /// Each pass removes satisfied requirements and applies the replacements they
    /// produce via `unify`. Requirements that are still undetermined are retained. The
    /// loop ends when a pass yields no replacements.
    ///
    /// When `is_final_check` is true, a requirement reporting `UnsatisfiedNow` is
    /// returned as an error. Otherwise it is kept for a later check.
    ///
    /// # Errors
    /// Diagnostics from checking a requirement, from an unsatisfied requirement on the
    /// final check, or from unifying a resulting replacement.
    pub fn check_requirements(&mut self, is_final_check: bool, ctx: &Context) -> Result<()> {
        // Once we are done type checking the rest of the entity, check all requirements
        loop {
            // Walk through all the requirements, checking each one. If the requirement
            // is still undetermined, take note to retain that id, otherwise store the
            // replacement to be performed
            let (retain, replacements_option): (Vec<_>, Vec<_>) = self
                .requirements
                .clone()
                .iter()
                .map(|req| match req.check(self, ctx)? {
                    requirements::RequirementResult::NoChange => Ok((true, None)),
                    requirements::RequirementResult::UnsatisfiedNow(diag) => {
                        if is_final_check {
                            Err(diag)
                        } else {
                            Ok((true, None))
                        }
                    }
                    requirements::RequirementResult::Satisfied(replacement) => {
                        self.trace_stack
                            .push(TraceStackEntry::ResolvedRequirement(req.clone()));
                        Ok((false, Some(replacement)))
                    }
                })
                .collect::<Result<Vec<_>>>()?
                .into_iter()
                .unzip();

            let replacements = replacements_option
                .into_iter()
                .flatten()
                .flatten()
                .collect::<Vec<_>>();

            // Drop all requirements that were satisfied in this pass; `retain` is
            // index-aligned with `self.requirements`. Their replacements are applied below
            self.requirements = self
                .requirements
                .drain(0..)
                .zip(retain)
                .filter_map(|(req, keep)| if keep { Some(req) } else { None })
                .collect();

            // No new information: nothing further can be resolved
            if replacements.is_empty() {
                break;
            }

            for Replacement { from, to, context } in replacements {
                self.unify(&to, &from, ctx).into_diagnostic_or_default(
                    from.loc(),
                    context,
                    self,
                )?;
            }
        }

        Ok(())
    }
2994
2995    pub fn get_replacement(&self, var: &TypeVarID) -> TypeVarID {
2996        self.replacements.get(*var)
2997    }
2998
2999    pub fn do_and_restore<T, E>(
3000        &mut self,
3001        inner: impl Fn(&mut Self) -> std::result::Result<T, E>,
3002    ) -> std::result::Result<T, E> {
3003        self.checkpoint();
3004        let result = inner(self);
3005        self.restore();
3006        result
3007    }
3008
3009    /// Create a "checkpoint" to which we can restore later using `restore`. This acts
3010    /// like a stack, nested checkpoints are possible
3011    fn checkpoint(&mut self) {
3012        self.trace_stack
3013            .push(TraceStackEntry::Enter("Creating checkpoint".to_string()));
3014        self.replacements.push();
3015    }
3016
3017    fn restore(&mut self) {
3018        self.replacements.discard_top();
3019        self.trace_stack.push(TraceStackEntry::Exit);
3020    }
3021}
3022
3023impl TypeState {
3024    pub fn print_equations(&self) {
3025        for (lhs, rhs) in &self.equations {
3026            println!(
3027                "{} -> {}",
3028                format!("{lhs}").blue(),
3029                format!("{}", rhs.debug_resolve(self)).green()
3030            )
3031        }
3032
3033        println!("\nReplacments:");
3034
3035        for repl_stack in &self.replacements.all() {
3036            let replacements = { repl_stack.borrow().clone() };
3037            for (lhs, rhs) in replacements.iter().sorted() {
3038                println!(
3039                    "{} -> {} ({} -> {})",
3040                    format!("{}", lhs.inner).blue(),
3041                    format!("{}", rhs.inner).green(),
3042                    format!("{}", lhs.debug_resolve(self)).blue(),
3043                    format!("{}", rhs.debug_resolve(self)).green(),
3044                )
3045            }
3046            println!("---")
3047        }
3048
3049        println!("\n Requirements:");
3050
3051        for requirement in &self.requirements {
3052            println!("{:?}", requirement)
3053        }
3054
3055        println!()
3056    }
3057}
3058
3059#[must_use]
3060pub struct UnificationBuilder {
3061    lhs: TypeVarID,
3062    rhs: TypeVarID,
3063}
3064impl UnificationBuilder {
3065    pub fn commit(
3066        self,
3067        state: &mut TypeState,
3068        ctx: &Context,
3069    ) -> std::result::Result<TypeVarID, UnificationError> {
3070        state.unify(&self.lhs, &self.rhs, ctx)
3071    }
3072}
3073
3074pub trait HasType: std::fmt::Debug {
3075    fn get_type(&self, state: &TypeState) -> TypeVarID {
3076        self.try_get_type(state)
3077            .expect(&format!("Did not find a type for {self:?}"))
3078    }
3079
3080    fn try_get_type(&self, state: &TypeState) -> Option<TypeVarID> {
3081        let id = self.get_type_impl(state);
3082        id.map(|id| state.get_replacement(&id))
3083    }
3084
3085    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID>;
3086
3087    fn unify_with(&self, rhs: &dyn HasType, state: &TypeState) -> UnificationBuilder {
3088        UnificationBuilder {
3089            lhs: self.get_type(state),
3090            rhs: rhs.get_type(state),
3091        }
3092    }
3093}
3094
3095impl HasType for TypeVarID {
3096    fn get_type_impl(&self, _state: &TypeState) -> Option<TypeVarID> {
3097        Some(*self)
3098    }
3099}
3100impl HasType for Loc<TypeVarID> {
3101    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3102        self.inner.try_get_type(state)
3103    }
3104}
3105impl HasType for TypedExpression {
3106    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3107        state.maybe_type_of(self).cloned()
3108    }
3109}
3110impl HasType for Expression {
3111    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3112        state.maybe_type_of(&TypedExpression::Id(self.id)).cloned()
3113    }
3114}
3115impl HasType for Loc<Expression> {
3116    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3117        state
3118            .maybe_type_of(&TypedExpression::Id(self.inner.id))
3119            .cloned()
3120    }
3121}
3122impl HasType for Pattern {
3123    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3124        state.maybe_type_of(&TypedExpression::Id(self.id)).cloned()
3125    }
3126}
3127impl HasType for Loc<Pattern> {
3128    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3129        state
3130            .maybe_type_of(&TypedExpression::Id(self.inner.id))
3131            .cloned()
3132    }
3133}
3134impl HasType for NameID {
3135    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3136        state
3137            .maybe_type_of(&TypedExpression::Name(self.clone()))
3138            .cloned()
3139    }
3140}