spade_typeinference/
lib.rs

// This algorithm is based on the excellent lecture here
// https://www.youtube.com/watch?v=xJXcZp2vgLs
//
// The basic idea is to go through every node in the HIR tree and, for every typed thing,
// add an equation indicating a constraint on that thing. This can only be done once
// and should be done by the visitor for that node. The visitor should then unify
// types according to the rules of the node.
8
9use std::cell::RefCell;
10use std::cmp::PartialEq;
11use std::collections::{BTreeMap, BTreeSet, HashMap};
12
13use colored::Colorize;
14use fixed_types::{t_int, t_uint};
15use hir::expression::CallKind;
16use hir::{
17    param_util, Binding, ConstGeneric, Parameter, PipelineRegMarkerExtra, TypeExpression, TypeSpec,
18    UnitHead, UnitKind, WalTrace, WhereClause,
19};
20use itertools::{Either, Itertools};
21use method_resolution::{FunctionLikeName, IntoImplTarget};
22use num::{BigInt, BigUint, Zero};
23use replacement::ReplacementStack;
24use serde::{Deserialize, Serialize};
25use spade_common::id_tracker::{ExprID, ImplID};
26use spade_common::num_ext::InfallibleToBigInt;
27use spade_diagnostics::diag_list::{DiagList, ResultExt};
28use spade_diagnostics::{diag_anyhow, diag_bail, Diagnostic};
29use spade_macros::trace_typechecker;
30use spade_types::meta_types::{unify_meta, MetaType};
31use trace_stack::TraceStack;
32use tracing::{info, trace};
33
34use spade_common::location_info::{Loc, WithLocation};
35use spade_common::name::{Identifier, NameID, Path};
36use spade_hir::param_util::{match_args_with_params, Argument};
37use spade_hir::symbol_table::{Patternable, PatternableKind, SymbolTable, TypeSymbol};
38use spade_hir::{self as hir, ConstGenericWithId, ImplTarget};
39use spade_hir::{
40    ArgumentList, Block, ExprKind, Expression, ItemList, Pattern, PatternArgument, Register,
41    Statement, TraitName, TraitSpec, TypeParam, Unit,
42};
43use spade_types::KnownType;
44
45use constraints::{
46    bits_to_store, ce_int, ce_var, ConstraintExpr, ConstraintSource, TypeConstraints,
47};
48use equation::{TemplateTypeVarID, TypeEquations, TypeVar, TypeVarID, TypedExpression};
49use error::{
50    error_pattern_type_mismatch, Result, UnificationError, UnificationErrorExt, UnificationTrace,
51};
52use requirements::{Replacement, Requirement};
53use trace_stack::{format_trace_stack, TraceStackEntry};
54use traits::{TraitList, TraitReq};
55
56use crate::error::TypeMismatch as Tm;
57use crate::requirements::ConstantInt;
58use crate::traits::{TraitImpl, TraitImplList};
59
60mod constraints;
61pub mod dump;
62pub mod equation;
63pub mod error;
64pub mod expression;
65pub mod fixed_types;
66pub mod method_resolution;
67pub mod mir_type_lowering;
68mod replacement;
69mod requirements;
70pub mod testutil;
71pub mod trace_stack;
72pub mod traits;
73
74pub struct Context<'a> {
75    pub symtab: &'a SymbolTable,
76    pub items: &'a ItemList,
77    pub trait_impls: &'a TraitImplList,
78}
79impl<'a> std::fmt::Debug for Context<'a> {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "{{context omitted}}")
82    }
83}
84
/// Debug helper that pushes a formatted message onto `$self.trace_stack` as a
/// `TraceStackEntry::Message`. The first argument is the type state; the rest are
/// `format!`-style arguments, e.g. `add_trace!(self, "unifying {} with {}", a, b)`.
// NOTE(allow) This is a debug macro which is not normally used but can come in handy
#[allow(unused_macros)]
macro_rules! add_trace {
    ($self:expr, $($arg : tt) *) => {
        // The trace stack stores `TraceStackEntry` values; `TraceStack` is the stack itself.
        $self.trace_stack.push(TraceStackEntry::Message(format!($($arg)*)))
    }
}
92
93#[derive(Debug)]
94pub enum GenericListSource<'a> {
95    /// For when you just need a new generic list but have no need to refer back
96    /// to it in the future
97    Anonymous,
98    /// For managing the generics of callable with the specified name when type checking
99    /// their body
100    Definition(&'a NameID),
101    ImplBlock {
102        target: &'a ImplTarget,
103        id: ImplID,
104    },
105    /// For expressions which instantiate generic items
106    Expression(ExprID),
107}
108
109/// Stored version of GenericListSource
110#[derive(Clone, Hash, Eq, PartialEq, Debug, Serialize, Deserialize)]
111pub enum GenericListToken {
112    Anonymous(usize),
113    Definition(NameID),
114    ImplBlock(ImplTarget, ImplID),
115    Expression(ExprID),
116}
117
118#[derive(Debug)]
119pub struct TurbofishCtx<'a> {
120    turbofish: &'a Loc<ArgumentList<TypeExpression>>,
121    prev_generic_list: &'a GenericListToken,
122    type_ctx: &'a Context<'a>,
123}
124
125#[derive(Clone, Serialize, Deserialize)]
126pub struct PipelineState {
127    current_stage_depth: TypeVarID,
128    total_depth: Loc<TypeVarID>,
129    pipeline_loc: Loc<()>,
130}
131
132/// State of the type inference algorithm
133#[derive(Clone, Serialize, Deserialize)]
134pub struct TypeState {
135    /// All types are referred to by their index to allow type vars changing inside
136    /// the type state while the types are "out in the wild". The TypeVarID is an index
137    /// into this type_vars list which is used to look up the actual type as currently
138    /// seen by the type state
139    type_vars: Vec<TypeVar>,
140    /// This key is used to prevent bugs when multiple type states are mixed. Each TypeVarID
141    /// holds the value of the key of the type state which created it, and this is checked
142    /// to ensure that type vars are not mixed. The key is initialized randomly on type
143    /// state creation
144    key: u64,
145    /// A type state can also support keys from other sources, this is tracked here
146    keys: BTreeSet<u64>,
147
148    equations: TypeEquations,
149    // NOTE: This is kind of redundant, we could use TypeVarIDs instead of having dedicated
150    // numbers for unknown types.
151    next_typeid: RefCell<u64>,
152    // List of the mapping between generic parameters and type vars.
153    // The key is the index of the expression for which this generic list is associated. (if this
154    // is a generic list for a call whose expression id is x to f<A, B>, then generic_lists[x] will
155    // be {A: <type var>, b: <type var>}
156    // Managed here because unification must update *all* TypeVars in existence.
157    generic_lists: HashMap<GenericListToken, HashMap<NameID, TypeVarID>>,
158
159    constraints: TypeConstraints,
160
161    /// Requirements which must be fulfilled but which do not guide further type inference.
162    /// For example, if seeing `let y = x.a` before knowing the type of `x`, a requirement is
163    /// added to say "x has field a, and y should be the type of that field"
164    #[serde(skip)]
165    requirements: Vec<Requirement>,
166
167    replacements: ReplacementStack,
168    // Skipping serialization of these is fine as we only push checkpoints during
169    // trait resolution
170    #[serde(skip)]
171    checkpoints: Vec<(Vec<Requirement>, TypeConstraints)>,
172
173    /// The type var containing the depth of the pipeline stage we are currently in
174    pipeline_state: Option<PipelineState>,
175
176    /// An error type that can be accessed anywhere without mut access. This is an option
177    /// to facilitate safe initialization, in practice it can never be None
178    error_type: Option<TypeVarID>,
179
180    pub trait_impls: TraitImplList,
181
182    #[serde(skip)]
183    pub trace_stack: TraceStack,
184
185    #[serde(skip)]
186    pub diags: DiagList,
187}
188
189impl TypeState {
190    /// Create a fresh type state, in most cases, this should be .create_child of an
191    /// existing type state
192    pub fn fresh() -> Self {
193        let key = fastrand::u64(..);
194        let mut result = Self {
195            type_vars: vec![],
196            key,
197            keys: [key].into_iter().collect(),
198            equations: HashMap::new(),
199            next_typeid: RefCell::new(0),
200            trace_stack: TraceStack::new(),
201            constraints: TypeConstraints::new(),
202            requirements: vec![],
203            replacements: ReplacementStack::new(),
204            generic_lists: HashMap::new(),
205            trait_impls: TraitImplList::new(),
206            checkpoints: vec![],
207            error_type: None,
208            pipeline_state: None,
209            diags: DiagList::new(),
210        };
211        result.error_type =
212            Some(result.add_type_var(TypeVar::Known(().nowhere(), KnownType::Error, vec![])));
213        result
214    }
215
216    pub fn create_child(&self) -> Self {
217        let mut result = self.clone();
218        result.key = fastrand::u64(..);
219        result.keys.insert(result.key);
220        result
221    }
222
223    pub fn add_type_var(&mut self, var: TypeVar) -> TypeVarID {
224        let idx = self.type_vars.len();
225        self.type_vars.push(var);
226        TypeVarID {
227            inner: idx,
228            type_state_key: self.key,
229        }
230    }
231
232    pub fn get_equations(&self) -> &TypeEquations {
233        &self.equations
234    }
235
236    pub fn get_constraints(&self) -> &TypeConstraints {
237        &self.constraints
238    }
239
240    // Get a generic list with a safe unwrap since a token is acquired
241    pub fn get_generic_list<'a>(
242        &'a self,
243        generic_list_token: &'a GenericListToken,
244    ) -> Option<&'a HashMap<NameID, TypeVarID>> {
245        self.generic_lists.get(generic_list_token)
246    }
247
248    #[tracing::instrument(level = "trace", skip_all)]
249    fn hir_type_expr_to_var<'a>(
250        &'a mut self,
251        e: &Loc<hir::TypeExpression>,
252        generic_list_token: &GenericListToken,
253    ) -> Result<TypeVarID> {
254        let id = match &e.inner {
255            hir::TypeExpression::Integer(i) => self.add_type_var(TypeVar::Known(
256                e.loc(),
257                KnownType::Integer(i.clone()),
258                vec![],
259            )),
260            hir::TypeExpression::String(s) => self.add_type_var(TypeVar::Known(
261                e.loc(),
262                KnownType::String(s.clone()),
263                vec![],
264            )),
265            hir::TypeExpression::TypeSpec(spec) => {
266                self.type_var_from_hir(e.loc(), &spec.clone(), generic_list_token)?
267            }
268            hir::TypeExpression::ConstGeneric(g) => {
269                let constraint = self.visit_const_generic(g, generic_list_token)?;
270
271                let tvar = self.new_generic_tlnumber(e.loc());
272                self.add_constraint(
273                    tvar.clone(),
274                    constraint,
275                    g.loc(),
276                    &tvar,
277                    ConstraintSource::Where,
278                );
279
280                tvar
281            }
282        };
283        Ok(id)
284    }
285
286    #[tracing::instrument(level = "trace", skip_all, fields(%hir_type))]
287    pub fn type_var_from_hir<'a>(
288        &'a mut self,
289        loc: Loc<()>,
290        hir_type: &crate::hir::TypeSpec,
291        generic_list_token: &GenericListToken,
292    ) -> Result<TypeVarID> {
293        let generic_list = self.get_generic_list(generic_list_token);
294        let var = match &hir_type {
295            hir::TypeSpec::Declared(base, params) => {
296                let params = params
297                    .iter()
298                    .map(|e| self.hir_type_expr_to_var(e, generic_list_token))
299                    .collect::<Result<Vec<_>>>()?;
300
301                self.add_type_var(TypeVar::Known(
302                    loc,
303                    KnownType::Named(base.inner.clone()),
304                    params,
305                ))
306            }
307            hir::TypeSpec::Generic(name) => match generic_list
308                .ok_or_else(|| diag_anyhow!(loc, "Found no generic list for {name}"))?
309                .get(&name.inner)
310            {
311                Some(t) => t.clone(),
312                None => {
313                    for list_source in self.generic_lists.keys() {
314                        info!("Generic lists exist for {list_source:?}");
315                    }
316                    info!("Current source is {generic_list_token:?}");
317                    panic!("No entry in generic list for {name:?}");
318                }
319            },
320            hir::TypeSpec::Tuple(inner) => {
321                let inner = inner
322                    .iter()
323                    .map(|t| self.type_var_from_hir(loc, t, generic_list_token))
324                    .collect::<Result<_>>()?;
325                self.add_type_var(TypeVar::tuple(loc, inner))
326            }
327            hir::TypeSpec::Array { inner, size } => {
328                let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
329                let size = self.hir_type_expr_to_var(size, generic_list_token)?;
330
331                self.add_type_var(TypeVar::array(loc, inner, size))
332            }
333            hir::TypeSpec::Wire(inner) => {
334                let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
335                self.add_type_var(TypeVar::wire(loc, inner))
336            }
337            hir::TypeSpec::Inverted(inner) => {
338                let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
339                self.add_type_var(TypeVar::inverted(loc, inner))
340            }
341            hir::TypeSpec::Wildcard(_) => self.new_generic_any(),
342            hir::TypeSpec::TraitSelf(_) => {
343                diag_bail!(
344                    loc,
345                    "Trying to convert TraitSelf to type inference type var"
346                )
347            }
348        };
349
350        Ok(var)
351    }
352
353    /// Returns the type of the expression with the specified id. Error if no equation
354    /// for the specified expression exists
355    pub fn type_of(&self, expr: &TypedExpression) -> TypeVarID {
356        if let Some(t) = self.equations.get(expr) {
357            *t
358        } else {
359            panic!("Tried looking up the type of {expr:?} but it was not found")
360        }
361    }
362
363    pub fn maybe_type_of(&self, expr: &TypedExpression) -> Option<&TypeVarID> {
364        self.equations.get(expr)
365    }
366
367    pub fn new_generic_int(&mut self, loc: Loc<()>, symtab: &SymbolTable) -> TypeVar {
368        TypeVar::Known(loc, t_int(symtab), vec![self.new_generic_tluint(loc)])
369    }
370
371    pub fn new_concrete_int(&mut self, size: BigUint, loc: Loc<()>) -> TypeVarID {
372        TypeVar::Known(loc, KnownType::Integer(size.to_bigint()), vec![]).insert(self)
373    }
374
375    /// Return a new generic int. The first returned value is int<N>, and the second
376    /// value is N
377    pub fn new_split_generic_int(
378        &mut self,
379        loc: Loc<()>,
380        symtab: &SymbolTable,
381    ) -> (TypeVarID, TypeVarID) {
382        let size = self.new_generic_tlint(loc);
383        let full = self.add_type_var(TypeVar::Known(loc, t_int(symtab), vec![size.clone()]));
384        (full, size)
385    }
386
387    pub fn new_split_generic_uint(
388        &mut self,
389        loc: Loc<()>,
390        symtab: &SymbolTable,
391    ) -> (TypeVarID, TypeVarID) {
392        let size = self.new_generic_tluint(loc);
393        let full = self.add_type_var(TypeVar::Known(loc, t_uint(symtab), vec![size.clone()]));
394        (full, size)
395    }
396
397    pub fn new_generic_with_meta(&mut self, loc: Loc<()>, meta: MetaType) -> TypeVarID {
398        let id = self.new_typeid();
399        self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), meta))
400    }
401
402    pub fn new_generic_type(&mut self, loc: Loc<()>) -> TypeVarID {
403        let id = self.new_typeid();
404        self.add_type_var(TypeVar::Unknown(
405            loc,
406            id,
407            TraitList::empty(),
408            MetaType::Type,
409        ))
410    }
411
412    pub fn new_generic_any(&mut self) -> TypeVarID {
413        let id = self.new_typeid();
414        // NOTE: ().nowhere() is fine here, we can never fail to unify with MetaType::Any so
415        // this loc will never be displayed
416        self.add_type_var(TypeVar::Unknown(
417            ().nowhere(),
418            id,
419            TraitList::empty(),
420            MetaType::Any,
421        ))
422    }
423
424    pub fn new_generic_tlbool(&mut self, loc: Loc<()>) -> TypeVarID {
425        let id = self.new_typeid();
426        self.add_type_var(TypeVar::Unknown(
427            loc,
428            id,
429            TraitList::empty(),
430            MetaType::Bool,
431        ))
432    }
433
434    pub fn new_generic_tlstr(&mut self, loc: Loc<()>) -> TypeVarID {
435        let id = self.new_typeid();
436        self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Str))
437    }
438
439    pub fn new_generic_tluint(&mut self, loc: Loc<()>) -> TypeVarID {
440        let id = self.new_typeid();
441        self.add_type_var(TypeVar::Unknown(
442            loc,
443            id,
444            TraitList::empty(),
445            MetaType::Uint,
446        ))
447    }
448
449    pub fn new_generic_tlint(&mut self, loc: Loc<()>) -> TypeVarID {
450        let id = self.new_typeid();
451        self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Int))
452    }
453
454    pub fn new_generic_tlnumber(&mut self, loc: Loc<()>) -> TypeVarID {
455        let id = self.new_typeid();
456        self.add_type_var(TypeVar::Unknown(
457            loc,
458            id,
459            TraitList::empty(),
460            MetaType::Number,
461        ))
462    }
463
464    pub fn new_generic_number(&mut self, loc: Loc<()>, ctx: &Context) -> (TypeVarID, TypeVarID) {
465        let number = ctx
466            .symtab
467            .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
468            .expect("Did not find number in symtab")
469            .0;
470        let id = self.new_typeid();
471        let size = self.new_generic_tluint(loc);
472        let t = TraitReq {
473            name: TraitName::Named(number.nowhere()),
474            type_params: vec![size.clone()],
475        }
476        .nowhere();
477        (
478            self.add_type_var(TypeVar::Unknown(
479                loc,
480                id,
481                TraitList::from_vec(vec![t]),
482                MetaType::Type,
483            )),
484            size,
485        )
486    }
487
488    pub fn new_generic_with_traits(&mut self, loc: Loc<()>, traits: TraitList) -> TypeVarID {
489        let id = self.new_typeid();
490        self.add_type_var(TypeVar::Unknown(loc, id, traits, MetaType::Type))
491    }
492
493    /// Returns the current pipeline state. `access_loc` is *not* used to specify where to get the
494    /// state from, it is only used for the Diagnostic::bug that gets emitted if there is no
495    /// pipeline state
496    pub fn get_pipeline_state<T>(&self, access_loc: &Loc<T>) -> Result<&PipelineState> {
497        self.pipeline_state
498            .as_ref()
499            .ok_or_else(|| diag_anyhow!(access_loc, "Expected to have a pipeline state"))
500    }
501
502    /// Visit a unit but before doing any preprocessing run the `pp` function.
503    pub fn visit_unit_with_preprocessing(
504        &mut self,
505        entity: &Loc<Unit>,
506        pp: impl Fn(&mut TypeState, &Loc<Unit>, &GenericListToken, &Context) -> Result<()>,
507        ctx: &Context,
508    ) -> Result<()> {
509        self.trait_impls = ctx.trait_impls.clone();
510
511        let generic_list = self.create_generic_list(
512            GenericListSource::Definition(&entity.name.name_id().inner),
513            &entity.head.unit_type_params,
514            &entity.head.scope_type_params,
515            None,
516            // NOTE: I'm not 100% sure we need to pass these here, the information
517            // is probably redundant
518            &entity.head.where_clauses,
519        )?;
520
521        pp(self, entity, &generic_list, ctx)?;
522
523        // Add equations for the inputs
524        for (name, t) in &entity.inputs {
525            let tvar = self.type_var_from_hir(t.loc(), t, &generic_list)?;
526            self.add_equation(TypedExpression::Name(name.inner.clone()), tvar)
527        }
528
529        if let UnitKind::Pipeline {
530            depth,
531            depth_typeexpr_id,
532        } = &entity.head.unit_kind.inner
533        {
534            self.setup_pipeline_state(
535                &entity.head.unit_kind,
536                &entity.body,
537                &generic_list,
538                depth,
539                depth_typeexpr_id,
540            )?;
541
542            let clock_index = if entity.head.is_nonstatic_method {
543                1
544            } else {
545                0
546            };
547
548            TypedExpression::Name(entity.inputs[clock_index].0.clone().inner)
549                .unify_with(&self.t_clock(entity.head.unit_kind.loc(), ctx.symtab), self)
550                .commit(self, ctx)
551                .into_diagnostic(
552                    entity.inputs[0].1.loc(),
553                    |diag,
554                     Tm {
555                         g: got,
556                         e: _expected,
557                     }| {
558                        diag.message(format!(
559                            "First pipeline argument must be a clock. Got {}",
560                            got.display(self)
561                        ))
562                        .primary_label("expected clock")
563                    },
564                    self,
565                )?;
566            // In order to catch negative depth early when they are specified as literals,
567            // we'll instantly check requirements here
568            self.check_requirements(false, ctx)?;
569        }
570
571        self.visit_expression(&entity.body, ctx, &generic_list);
572
573        // Ensure that the output type matches what the user specified, and unit otherwise
574        if let Some(output_type) = &entity.head.output_type {
575            let tvar = self.type_var_from_hir(output_type.loc(), output_type, &generic_list)?;
576
577            self.trace_stack.push(TraceStackEntry::Message(format!(
578                "Unifying with output type {}",
579                tvar.debug_resolve(self)
580            )));
581            self.unify(&TypedExpression::Id(entity.body.inner.id), &tvar, ctx)
582                .into_diagnostic_no_expected_source(
583                    &entity.body,
584                    |diag,
585                     Tm {
586                         g: got,
587                         e: expected,
588                     }| {
589                        let expected = expected.display(self);
590                        let got = got.display(self);
591                        // FIXME: Might want to check if there is no return value in the block
592                        // and in that case suggest adding it.
593                        diag.message(format!(
594                            "Output type mismatch. Expected {expected}, got {got}"
595                        ))
596                        .primary_label(format!("Found type {got}"))
597                        .secondary_label(output_type, format!("{expected} type specified here"))
598                    },
599                    self,
600                )?;
601        } else {
602            // No output type, so unify with the unit type.
603            TypedExpression::Id(entity.body.inner.id)
604                .unify_with(&self.add_type_var(TypeVar::unit(entity.head.name.loc())), self)
605                .commit(self, ctx)
606                .into_diagnostic_no_expected_source(entity.body.loc(), |diag, Tm{g: got, e: _expected}| {
607                    diag.message("Output type mismatch")
608                        .primary_label(format!("Found type {got}", got = got.display(self)))
609                        .note(format!(
610                            "The {} does not specify a return type.\nAdd a return type, or remove the return value.",
611                            entity.head.unit_kind.name()
612                        ))
613                }, self)?;
614        }
615
616        if let Some(PipelineState {
617            current_stage_depth,
618            pipeline_loc,
619            total_depth,
620        }) = self.pipeline_state.clone()
621        {
622            self.unify(&total_depth.inner, &current_stage_depth, ctx)
623                .into_diagnostic_no_expected_source(
624                    pipeline_loc,
625                    |diag, tm| {
626                        let (e, g) = tm.display_e_g(self);
627                        diag.message(format!("Pipeline depth mismatch. Expected {g} got {e}"))
628                            .primary_label(format!("Found {e} stages in this pipeline"))
629                    },
630                    self,
631                )?;
632        }
633
634        self.check_requirements(true, ctx)?;
635
636        // NOTE: We may accidentally leak a stage depth if this function returns early. However,
637        // since we only use stage depths in pipelines where we re-set it when we enter,
638        // accidentally leaving a depth in place should not trigger any *new* compiler bugs, but
639        // may help track something down
640        self.pipeline_state = None;
641
642        Ok(())
643    }
644
645    fn setup_pipeline_state(
646        &mut self,
647        unit_kind: &Loc<UnitKind>,
648        body_loc: &Loc<Expression>,
649        generic_list: &GenericListToken,
650        depth: &Loc<TypeExpression>,
651        depth_typeexpr_id: &ExprID,
652    ) -> Result<()> {
653        let depth_var = self.hir_type_expr_to_var(depth, generic_list)?;
654        self.add_equation(TypedExpression::Id(*depth_typeexpr_id), depth_var.clone());
655        self.pipeline_state = Some(PipelineState {
656            current_stage_depth: self.add_type_var(TypeVar::Known(
657                unit_kind.loc(),
658                KnownType::Integer(BigInt::zero()),
659                vec![],
660            )),
661            pipeline_loc: body_loc.loc(),
662            total_depth: depth_var.clone().at_loc(depth),
663        });
664        self.add_requirement(Requirement::PositivePipelineDepth {
665            depth: depth_var.at_loc(depth),
666        });
667        Ok(())
668    }
669
670    #[trace_typechecker]
671    #[tracing::instrument(level = "trace", skip_all, fields(%entity.name))]
672    pub fn visit_unit(&mut self, entity: &Loc<Unit>, ctx: &Context) -> Result<()> {
673        self.visit_unit_with_preprocessing(entity, |_, _, _, _| Ok(()), ctx)
674    }
675
676    #[trace_typechecker]
677    #[tracing::instrument(level = "trace", skip_all)]
678    fn visit_argument_list(
679        &mut self,
680        args: &Loc<ArgumentList<Expression>>,
681        ctx: &Context,
682        generic_list: &GenericListToken,
683    ) -> Result<()> {
684        for expr in args.expressions() {
685            self.visit_expression(expr, ctx, generic_list);
686        }
687        Ok(())
688    }
689
690    #[trace_typechecker]
691    fn type_check_argument_list(
692        &mut self,
693        args: &[Argument<Expression, TypeSpec>],
694        ctx: &Context,
695        generic_list: &GenericListToken,
696    ) -> Result<()> {
697        for Argument {
698            target,
699            target_type,
700            value,
701            kind,
702        } in args.iter()
703        {
704            let target_type = self.type_var_from_hir(value.loc(), target_type, generic_list)?;
705
706            let loc = match kind {
707                hir::param_util::ArgumentKind::Positional => value.loc(),
708                hir::param_util::ArgumentKind::Named => value.loc(),
709                hir::param_util::ArgumentKind::ShortNamed => target.loc(),
710            };
711
712            self.unify(&value.inner, &target_type, ctx)
713                .into_diagnostic(
714                    loc,
715                    |d, tm| {
716                        let (expected, got) = tm.display_e_g(self);
717                        d.message(format!(
718                            "Argument type mismatch. Expected {expected} got {got}"
719                        ))
720                        .primary_label(format!("expected {expected}"))
721                    },
722                    self,
723                )?;
724        }
725
726        Ok(())
727    }
728
729    #[trace_typechecker]
730    pub fn visit_expression_result(
731        &mut self,
732        expression: &Loc<Expression>,
733        ctx: &Context,
734        generic_list: &GenericListToken,
735        new_type: TypeVarID,
736    ) -> Result<()> {
737        // Recurse down the expression
738        match &expression.inner.kind {
739            ExprKind::Error => {
740                new_type
741                    .unify_with(&self.t_err(expression.loc()), self)
742                    .commit(self, ctx)
743                    .unwrap();
744            }
745            ExprKind::Identifier(_) => self.visit_identifier(expression, ctx)?,
746            ExprKind::TypeLevelInteger(_) => {
747                self.visit_type_level_integer(expression, generic_list, ctx)?
748            }
749            ExprKind::IntLiteral(_, _) => self.visit_int_literal(expression, ctx)?,
750            ExprKind::BoolLiteral(_) => self.visit_bool_literal(expression, ctx)?,
751            ExprKind::TriLiteral(_) => self.visit_tri_literal(expression, ctx)?,
752            ExprKind::TupleLiteral(_) => self.visit_tuple_literal(expression, ctx, generic_list)?,
753            ExprKind::TupleIndex(_, _) => self.visit_tuple_index(expression, ctx, generic_list)?,
754            ExprKind::ArrayLiteral(_) => self.visit_array_literal(expression, ctx, generic_list)?,
755            ExprKind::ArrayShorthandLiteral(_, _) => {
756                self.visit_array_shorthand_literal(expression, ctx, generic_list)?
757            }
758            ExprKind::CreatePorts => self.visit_create_ports(expression, ctx, generic_list)?,
759            ExprKind::FieldAccess(_, _) => {
760                self.visit_field_access(expression, ctx, generic_list)?
761            }
762            ExprKind::MethodCall { .. } => self.visit_method_call(expression, ctx, generic_list)?,
763            ExprKind::Index(_, _) => self.visit_index(expression, ctx, generic_list)?,
764            ExprKind::RangeIndex { .. } => self.visit_range_index(expression, ctx, generic_list)?,
765            ExprKind::Block(_) => self.visit_block_expr(expression, ctx, generic_list)?,
766            ExprKind::If(_, _, _) => self.visit_if(expression, ctx, generic_list)?,
767            ExprKind::Match(_, _) => self.visit_match(expression, ctx, generic_list)?,
768            ExprKind::BinaryOperator(_, _, _) => {
769                self.visit_binary_operator(expression, ctx, generic_list)?
770            }
771            ExprKind::UnaryOperator(_, _) => {
772                self.visit_unary_operator(expression, ctx, generic_list)?
773            }
774            ExprKind::Call {
775                kind,
776                callee,
777                args,
778                turbofish,
779                safety: _,
780            } => {
781                let head = ctx.symtab.unit_by_id(&callee.inner);
782
783                self.handle_function_like(
784                    expression.map_ref(|e| e.id),
785                    &expression.get_type(self),
786                    &FunctionLikeName::Free(callee.inner.clone()),
787                    &head,
788                    kind,
789                    args,
790                    ctx,
791                    true,
792                    false,
793                    turbofish.as_ref().map(|turbofish| TurbofishCtx {
794                        turbofish,
795                        prev_generic_list: generic_list,
796                        type_ctx: ctx,
797                    }),
798                    generic_list,
799                )?;
800            }
801            ExprKind::PipelineRef { .. } => {
802                self.visit_pipeline_ref(expression, generic_list, ctx)?;
803            }
804            ExprKind::StageReady | ExprKind::StageValid => {
805                expression
806                    .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
807                    .commit(self, ctx)
808                    .into_default_diagnostic(expression, self)?;
809            }
810
811            ExprKind::TypeLevelIf(cond, on_true, on_false) => {
812                let cond_var = self.visit_const_generic_with_id(
813                    cond,
814                    generic_list,
815                    ConstraintSource::TypeLevelIf,
816                    ctx,
817                )?;
818                let t_bool = self.new_generic_tlbool(cond.loc());
819                self.unify(&cond_var, &t_bool, ctx).into_diagnostic(
820                    cond,
821                    |diag, tm| {
822                        let (_e, g) = tm.display_e_g(self);
823                        diag.message(format!("gen if conditions must be #bool, got {g}"))
824                    },
825                    self,
826                )?;
827
828                self.visit_expression(on_true, ctx, generic_list);
829                self.visit_expression(on_false, ctx, generic_list);
830
831                self.unify_expression_generic_error(expression, on_true.as_ref(), ctx)?;
832                self.unify_expression_generic_error(expression, on_false.as_ref(), ctx)?;
833            }
834            ExprKind::LambdaDef {
835                unit_kind,
836                arguments,
837                body,
838                lambda_type,
839                type_params,
840                outer_generic_params,
841                lambda_unit: _,
842                clock,
843                captures,
844            } => {
845                for arg in arguments {
846                    self.visit_pattern(arg, ctx, generic_list)?;
847                }
848
849                if let Some(clock) = clock {
850                    clock
851                        .unify_with(&self.t_clock(clock.loc(), ctx.symtab), self)
852                        .commit(self, ctx)
853                        .into_default_diagnostic(clock, self)?;
854                }
855
856                let outer_pipeline_state = self.pipeline_state.take();
857                if let UnitKind::Pipeline {
858                    depth,
859                    depth_typeexpr_id,
860                } = &unit_kind.inner
861                {
862                    self.setup_pipeline_state(
863                        unit_kind,
864                        body,
865                        &generic_list,
866                        depth,
867                        depth_typeexpr_id,
868                    )?;
869                }
870                self.visit_expression(body, ctx, generic_list);
871                self.pipeline_state = outer_pipeline_state;
872
873                let lambda_params = arguments
874                    .iter()
875                    .map(|arg| arg.get_type(self))
876                    .chain(vec![body.get_type(self)])
877                    .chain(captures.iter().map(|(_, cap_name)| cap_name.get_type(self)))
878                    .chain(
879                        outer_generic_params
880                            .iter()
881                            .map(|cap| {
882                                let t = self
883                                    .get_generic_list(generic_list)
884                                    .ok_or_else(|| {
885                                        diag_anyhow!(
886                                            expression,
887                                            "Found a captured generic but no generic list"
888                                        )
889                                    })?
890                                    .get(&cap.name_in_body)
891                                    .ok_or_else(|| {
892                                        diag_anyhow!(
893                                            &cap.name_in_body,
894                                            "Did not find an entry for {} in lambda generic list",
895                                            cap.name_in_body
896                                        )
897                                    });
898                                Ok(t?.clone())
899                            })
900                            .collect::<Result<Vec<_>>>()?
901                            .into_iter(),
902                    )
903                    .collect::<Vec<_>>();
904
905                let self_type = TypeVar::Known(
906                    expression.loc(),
907                    KnownType::Named(lambda_type.clone()),
908                    lambda_params.clone(),
909                );
910
911                let unit_generic_list = self.create_generic_list(
912                    GenericListSource::Expression(expression.id),
913                    &type_params.all().cloned().collect::<Vec<_>>(),
914                    &[],
915                    None,
916                    &[],
917                )?;
918
919                for (p, tp) in lambda_params.iter().zip(type_params.all()) {
920                    let gl = self.get_generic_list(&unit_generic_list).unwrap();
921                    // Safe unwrap, we're just unifying unknowns with unknowns
922                    p.unify_with(
923                        gl.get(&tp.name_id).ok_or_else(|| {
924                            diag_anyhow!(
925                                expression,
926                                "Lambda unit list did not contain {}",
927                                tp.name_id
928                            )
929                        })?,
930                        self,
931                    )
932                    .commit(self, ctx)
933                    .into_default_diagnostic(expression, self)?;
934                }
935                expression
936                    .unify_with(&self.add_type_var(self_type), self)
937                    .commit(self, ctx)
938                    .into_default_diagnostic(expression, self)?;
939            }
940            ExprKind::StaticUnreachable(_) => {}
941            ExprKind::Null => {}
942        }
943        Ok(())
944    }
945
946    #[tracing::instrument(level = "trace", skip_all)]
947    pub fn visit_expression(
948        &mut self,
949        expression: &Loc<Expression>,
950        ctx: &Context,
951        generic_list: &GenericListToken,
952    ) {
953        let new_type = self.new_generic_type(expression.loc());
954        self.add_equation(TypedExpression::Id(expression.inner.id), new_type);
955
956        match self.visit_expression_result(expression, ctx, generic_list, new_type) {
957            Ok(_) => {}
958            Err(e) => {
959                new_type
960                    .unify_with(&self.t_err(expression.loc()), self)
961                    .commit(self, ctx)
962                    .unwrap();
963
964                self.diags.errors.push(e);
965            }
966        }
967    }
968
    /// Common handler for calls to entities, functions and pipelines.
    ///
    /// Steps, in order:
    /// 1. Creates a fresh generic list for the callee, keyed by `expression_id`. It
    ///    covers the callee's unit and scope type parameters, applies `turbofish` if
    ///    given, and checks the callee's `where` clauses.
    /// 2. For a pipeline called as a pipeline, unifies the call-site depth (resolved
    ///    in the caller's `generic_list`) with the definition's depth (resolved in the
    ///    callee's list).
    /// 3. Optionally visits the argument expressions (`visit_args`).
    /// 4. Matches arguments to the callee's inputs and applies extra constraints for a
    ///    few special `std` functions.
    /// 5. Type checks the arguments against the inputs and unifies `expression_type`
    ///    with the callee's return type, or with the empty tuple if it declares none.
    ///
    /// Returns an error diagnostic if argument matching fails (labelled with the
    /// callee's definition) or if any of the unifications above fail.
    #[tracing::instrument(level = "trace", skip_all, fields(%name))]
    #[trace_typechecker]
    fn handle_function_like(
        &mut self,
        expression_id: Loc<ExprID>,
        expression_type: &TypeVarID,
        name: &FunctionLikeName,
        head: &Loc<UnitHead>,
        call_kind: &CallKind,
        args: &Loc<ArgumentList<Expression>>,
        ctx: &Context,
        // Whether or not to visit the argument expressions passed to the function here. This
        // should not be done if the expressions have been visited before, for example, when
        // handling methods
        visit_args: bool,
        // If we are calling a method, we have an implicit self argument, so any error
        // that reports the number of arguments should count one fewer
        is_method: bool,
        turbofish: Option<TurbofishCtx>,
        generic_list: &GenericListToken,
    ) -> Result<()> {
        // Add new symbols for all the type parameters
        let unit_generic_list = self.create_generic_list(
            GenericListSource::Expression(expression_id.inner),
            &head.unit_type_params,
            &head.scope_type_params,
            turbofish,
            &head.where_clauses,
        )?;

        // Only a pipeline instantiated as a pipeline has a depth to check here. Other
        // unit-kind/call-kind combinations are not checked by this function.
        // NOTE(review): presumably those mismatches are diagnosed elsewhere — confirm.
        match (&head.unit_kind.inner, call_kind) {
            (
                UnitKind::Pipeline {
                    depth: udepth,
                    depth_typeexpr_id: _,
                },
                CallKind::Pipeline {
                    inst_loc: _,
                    depth: cdepth,
                    depth_typeexpr_id: cdepth_typeexpr_id,
                },
            ) => {
                // The definition depth is expressed in the callee's generics, while the
                // call-site depth is expressed in the caller's generics.
                let definition_depth = self.hir_type_expr_to_var(udepth, &unit_generic_list)?;
                let call_depth = self.hir_type_expr_to_var(cdepth, generic_list)?;

                // NOTE: We're not adding udepth_typeexpr_id here as that would break
                // in the future if we try to do recursion. We will also never need to look
                // up the depth in the definition
                self.add_equation(TypedExpression::Id(*cdepth_typeexpr_id), call_depth.clone());

                self.unify(&call_depth, &definition_depth, ctx)
                    .into_diagnostic_no_expected_source(
                        cdepth,
                        |diag, tm| {
                            let (e, g) = tm.display_e_g(self);
                            diag.message("Pipeline depth mismatch")
                                .primary_label(format!("Expected depth {e}, got {g}"))
                                .secondary_label(udepth, format!("{name} has depth {e}"))
                        },
                        self,
                    )?;
            }
            _ => {}
        }

        if visit_args {
            self.visit_argument_list(args, ctx, &generic_list)?;
        }

        // Type parameters of the callee; indexed positionally by `generic_arg!` below.
        let type_params = &head.get_type_params();

        // Special handling of built in functions
        // For each listed path, the handler runs only if the path resolves in the symbol
        // table and the resolved name is the callee. A path that fails to resolve is
        // treated as "not this function".
        macro_rules! handle_special_functions {
            ($([$($path:expr),*] => $handler:expr),*) => {
                $(
                    let path = Path(vec![$(Identifier($path.to_string()).nowhere()),*]).nowhere();
                    if ctx.symtab
                        .try_lookup_final_id(&path, &[])
                        .map(|n| &FunctionLikeName::Free(n) == name)
                        .unwrap_or(false)
                    {
                        $handler
                    };
                )*
            }
        }

        // NOTE: These would be better as a closure, but unfortunately that does not satisfy
        // the borrow checker
        // Looks up the callee-side type variable for the `$idx`th type parameter. It
        // indexes both `type_params` and the generic list, so a missing entry panics.
        macro_rules! generic_arg {
            ($idx:expr) => {
                self.get_generic_list(&unit_generic_list)
                    .ok_or_else(|| diag_anyhow!(expression_id, "Found no generic list for call"))?
                    [&type_params[$idx].name_id()]
                    .clone()
            };
        }

        // Pair the call's arguments with the callee's inputs, pointing at the callee's
        // definition if they do not line up.
        let matched_args =
            match_args_with_params(args, &head.inputs.inner, is_method).map_err(|e| {
                let diag: Diagnostic = e.into();
                diag.secondary_label(
                    head,
                    format!("{kind} defined here", kind = head.unit_kind.name()),
                )
            })?;

        handle_special_functions! {
            ["std", "conv", "concat"] => {
                self.handle_concat(
                    expression_id,
                    generic_arg!(0),
                    generic_arg!(1),
                    generic_arg!(2),
                    &matched_args,
                    ctx
                )?
            },
            ["std", "conv", "trunc"] => {
                self.handle_trunc(
                    expression_id,
                    generic_arg!(0),
                    generic_arg!(1),
                    &matched_args,
                    ctx
                )?
            },
            ["std", "ops", "comb_div"] => {
                self.handle_comb_mod_or_div(
                    generic_arg!(0),
                    &matched_args,
                    ctx
                )?
            },
            ["std", "ops", "comb_mod"] => {
                self.handle_comb_mod_or_div(
                    generic_arg!(0),
                    &matched_args,
                    ctx
                )?
            },
            ["std", "mem", "clocked_memory"]  => {
                let num_elements = generic_arg!(0);
                let addr_size = generic_arg!(2);

                self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
            },
            // NOTE: the last argument of _init does not need special treatment, so
            // we can use the same code path for both for now
            ["std", "mem", "clocked_memory_init"]  => {
                let num_elements = generic_arg!(0);
                let addr_size = generic_arg!(2);

                self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
            },
            // Note the generic parameter order differs from clocked_memory here: index 0
            // is taken as the address size and index 2 as the element count.
            ["std", "mem", "read_memory"]  => {
                let addr_size = generic_arg!(0);
                let num_elements = generic_arg!(2);

                self.handle_read_memory(num_elements, addr_size, &matched_args, ctx)?
            }
        };

        // Unify the types of the arguments
        self.type_check_argument_list(&matched_args, ctx, &unit_generic_list)?;

        // A unit without a declared output type returns the empty tuple.
        let return_type = head
            .output_type
            .as_ref()
            .map(|o| self.type_var_from_hir(expression_id.loc(), o, &unit_generic_list))
            .transpose()?
            .unwrap_or_else(|| {
                self.add_type_var(TypeVar::Known(
                    expression_id.loc(),
                    KnownType::Tuple,
                    vec![],
                ))
            });

        self.unify(expression_type, &return_type, ctx)
            .into_default_diagnostic(expression_id.loc(), self)?;

        Ok(())
    }
1154
1155    pub fn handle_concat(
1156        &mut self,
1157        expression_id: Loc<ExprID>,
1158        source_lhs_ty: TypeVarID,
1159        source_rhs_ty: TypeVarID,
1160        source_result_ty: TypeVarID,
1161        args: &[Argument<Expression, TypeSpec>],
1162        ctx: &Context,
1163    ) -> Result<()> {
1164        let (lhs_type, lhs_size) = self.new_generic_number(expression_id.loc(), ctx);
1165        let (rhs_type, rhs_size) = self.new_generic_number(expression_id.loc(), ctx);
1166        let (result_type, result_size) = self.new_generic_number(expression_id.loc(), ctx);
1167        self.unify(&source_lhs_ty, &lhs_type, ctx)
1168            .into_default_diagnostic(args[0].value.loc(), self)?;
1169        self.unify(&source_rhs_ty, &rhs_type, ctx)
1170            .into_default_diagnostic(args[1].value.loc(), self)?;
1171        self.unify(&source_result_ty, &result_type, ctx)
1172            .into_default_diagnostic(expression_id.loc(), self)?;
1173
1174        // Result size is sum of input sizes
1175        self.add_constraint(
1176            result_size.clone(),
1177            ce_var(&lhs_size) + ce_var(&rhs_size),
1178            expression_id.loc(),
1179            &result_size,
1180            ConstraintSource::Concatenation,
1181        );
1182        self.add_constraint(
1183            lhs_size.clone(),
1184            ce_var(&result_size) + -ce_var(&rhs_size),
1185            args[0].value.loc(),
1186            &lhs_size,
1187            ConstraintSource::Concatenation,
1188        );
1189        self.add_constraint(
1190            rhs_size.clone(),
1191            ce_var(&result_size) + -ce_var(&lhs_size),
1192            args[1].value.loc(),
1193            &rhs_size,
1194            ConstraintSource::Concatenation,
1195        );
1196
1197        self.add_requirement(Requirement::SharedBase(vec![
1198            lhs_type.at_loc(args[0].value),
1199            rhs_type.at_loc(args[1].value),
1200            result_type.at_loc(&expression_id.loc()),
1201        ]));
1202        Ok(())
1203    }
1204
1205    pub fn handle_trunc(
1206        &mut self,
1207        expression_id: Loc<ExprID>,
1208        source_in_ty: TypeVarID,
1209        source_result_ty: TypeVarID,
1210        args: &[Argument<Expression, TypeSpec>],
1211        ctx: &Context,
1212    ) -> Result<()> {
1213        let (in_ty, _) = self.new_generic_number(expression_id.loc(), ctx);
1214        let (result_type, _) = self.new_generic_number(expression_id.loc(), ctx);
1215        self.unify(&source_in_ty, &in_ty, ctx)
1216            .into_default_diagnostic(args[0].value.loc(), self)?;
1217        self.unify(&source_result_ty, &result_type, ctx)
1218            .into_default_diagnostic(expression_id.loc(), self)?;
1219
1220        self.add_requirement(Requirement::SharedBase(vec![
1221            in_ty.at_loc(args[0].value),
1222            result_type.at_loc(&expression_id.loc()),
1223        ]));
1224        Ok(())
1225    }
1226
1227    pub fn handle_comb_mod_or_div(
1228        &mut self,
1229        n_ty: TypeVarID,
1230        args: &[Argument<Expression, TypeSpec>],
1231        ctx: &Context,
1232    ) -> Result<()> {
1233        let (num, _) = self.new_generic_number(args[0].value.loc(), ctx);
1234        self.unify(&n_ty, &num, ctx)
1235            .into_default_diagnostic(args[0].value.loc(), self)?;
1236        Ok(())
1237    }
1238
1239    pub fn handle_clocked_memory(
1240        &mut self,
1241        num_elements: TypeVarID,
1242        addr_size_arg: TypeVarID,
1243        args: &[Argument<Expression, TypeSpec>],
1244        ctx: &Context,
1245    ) -> Result<()> {
1246        // FIXME: When we support where clauses, we should move this
1247        // definition into the definition of clocked_memory
1248        let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
1249        let arg1_loc = args[1].value.loc();
1250        let tup = TypeVar::tuple(
1251            args[1].value.loc(),
1252            vec![
1253                self.new_generic_type(arg1_loc),
1254                addr_type,
1255                self.new_generic_type(arg1_loc),
1256            ],
1257        );
1258        let port_type = TypeVar::array(
1259            arg1_loc,
1260            self.add_type_var(tup),
1261            self.new_generic_tluint(arg1_loc),
1262        )
1263        .insert(self);
1264
1265        self.add_constraint(
1266            addr_size.clone(),
1267            bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
1268            args[1].value.loc(),
1269            &port_type,
1270            ConstraintSource::MemoryIndexing,
1271        );
1272
1273        // NOTE: Unwrap is safe, size is still generic at this point
1274        self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1275        self.unify_expression_generic_error(args[1].value, &port_type, ctx)?;
1276
1277        Ok(())
1278    }
1279
1280    pub fn handle_read_memory(
1281        &mut self,
1282        num_elements: TypeVarID,
1283        addr_size_arg: TypeVarID,
1284        args: &[Argument<Expression, TypeSpec>],
1285        ctx: &Context,
1286    ) -> Result<()> {
1287        let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
1288
1289        self.add_constraint(
1290            addr_size.clone(),
1291            bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
1292            args[1].value.loc(),
1293            &addr_type,
1294            ConstraintSource::MemoryIndexing,
1295        );
1296
1297        // NOTE: Unwrap is safe, size is still generic at this point
1298        self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1299
1300        Ok(())
1301    }
1302
1303    #[tracing::instrument(level = "trace", skip(self, turbofish, where_clauses))]
1304    pub fn create_generic_list(
1305        &mut self,
1306        source: GenericListSource,
1307        type_params: &[Loc<TypeParam>],
1308        scope_type_params: &[Loc<TypeParam>],
1309        turbofish: Option<TurbofishCtx>,
1310        where_clauses: &[Loc<WhereClause>],
1311    ) -> Result<GenericListToken> {
1312        let turbofish_params = if let Some(turbofish) = turbofish.as_ref() {
1313            if type_params.is_empty() {
1314                return Err(Diagnostic::error(
1315                    turbofish.turbofish,
1316                    "Turbofish on non-generic function",
1317                )
1318                .primary_label("Turbofish on non-generic function"));
1319            }
1320
1321            let matched_params =
1322                param_util::match_args_with_params(turbofish.turbofish, &type_params, false)?;
1323
1324            // We want this to be positional, but the params we get from matching are
1325            // named, transform it. We'll do some unwrapping here, but it is safe
1326            // because we know all params are present
1327            matched_params
1328                .iter()
1329                .map(|matched_param| {
1330                    let i = type_params
1331                        .iter()
1332                        .enumerate()
1333                        .find_map(|(i, param)| match &param.inner {
1334                            TypeParam {
1335                                ident,
1336                                name_id: _,
1337                                trait_bounds: _,
1338                                meta: _,
1339                            } => {
1340                                if ident == matched_param.target {
1341                                    Some(i)
1342                                } else {
1343                                    None
1344                                }
1345                            }
1346                        })
1347                        .unwrap();
1348                    (i, matched_param)
1349                })
1350                .sorted_by_key(|(i, _)| *i)
1351                .map(|(_, mp)| Some(mp.value))
1352                .collect::<Vec<_>>()
1353        } else {
1354            type_params.iter().map(|_| None).collect::<Vec<_>>()
1355        };
1356
1357        let mut inline_trait_bounds: Vec<Loc<WhereClause>> = vec![];
1358
1359        let scope_type_params = scope_type_params
1360            .iter()
1361            .map(|param| {
1362                let hir::TypeParam {
1363                    ident,
1364                    name_id,
1365                    trait_bounds,
1366                    meta,
1367                } = &param.inner;
1368                if !trait_bounds.is_empty() {
1369                    if let MetaType::Type = meta {
1370                        inline_trait_bounds.push(
1371                            WhereClause::Type {
1372                                target: name_id.clone().at_loc(ident),
1373                                traits: trait_bounds.clone(),
1374                            }
1375                            .at_loc(param),
1376                        );
1377                    } else {
1378                        return Err(Diagnostic::bug(param, "Trait bounds on generic int")
1379                            .primary_label("Trait bounds are only allowed on type parameters"));
1380                    }
1381                }
1382                Ok((
1383                    name_id.clone(),
1384                    self.new_generic_with_meta(param.loc(), meta.clone()),
1385                ))
1386            })
1387            .collect::<Result<Vec<_>>>()?;
1388
1389        let new_list = type_params
1390            .iter()
1391            .enumerate()
1392            .map(|(i, param)| {
1393                let hir::TypeParam {
1394                    ident,
1395                    name_id,
1396                    trait_bounds,
1397                    meta,
1398                } = &param.inner;
1399
1400                let t = self.new_generic_with_meta(param.loc(), meta.clone());
1401
1402                if let Some(tf) = &turbofish_params[i] {
1403                    let tf_ctx = turbofish.as_ref().unwrap();
1404                    let ty = self.hir_type_expr_to_var(tf, tf_ctx.prev_generic_list)?;
1405                    self.unify(&ty, &t, tf_ctx.type_ctx)
1406                        .into_default_diagnostic(param, self)?;
1407                }
1408
1409                if !trait_bounds.is_empty() {
1410                    if let MetaType::Type = meta {
1411                        inline_trait_bounds.push(
1412                            WhereClause::Type {
1413                                target: name_id.clone().at_loc(ident),
1414                                traits: trait_bounds.clone(),
1415                            }
1416                            .at_loc(param),
1417                        );
1418                    }
1419                    Ok((name_id.clone(), t))
1420                } else {
1421                    Ok((name_id.clone(), t))
1422                }
1423            })
1424            .collect::<Result<Vec<_>>>()?
1425            .into_iter()
1426            .chain(scope_type_params.into_iter())
1427            .map(|(name, t)| (name, t.clone()))
1428            .collect::<HashMap<_, _>>();
1429
1430        self.trace_stack.push(TraceStackEntry::NewGenericList(
1431            new_list
1432                .iter()
1433                .map(|(name, var)| (name.clone(), var.debug_resolve(self)))
1434                .collect(),
1435        ));
1436
1437        let token = self.add_mapped_generic_list(source, new_list.clone());
1438
1439        for constraint in where_clauses.iter().chain(inline_trait_bounds.iter()) {
1440            match &constraint.inner {
1441                WhereClause::Type { target, traits } => {
1442                    self.visit_trait_bounds(target, traits.as_slice(), &token)?;
1443                }
1444                WhereClause::Int { target, constraint } => {
1445                    let int_constraint = self.visit_const_generic(constraint, &token)?;
1446                    let tvar = new_list.get(target).ok_or_else(|| {
1447                        Diagnostic::error(
1448                            target,
1449                            format!("{target} is not a generic parameter on this unit"),
1450                        )
1451                        .primary_label("Not a generic parameter")
1452                    })?;
1453
1454                    self.add_constraint(
1455                        tvar.clone(),
1456                        int_constraint,
1457                        constraint.loc(),
1458                        &tvar,
1459                        ConstraintSource::Where,
1460                    );
1461                }
1462            }
1463        }
1464
1465        Ok(token)
1466    }
1467
1468    /// Adds a generic list with parameters already mapped to types
1469    pub fn add_mapped_generic_list(
1470        &mut self,
1471        source: GenericListSource,
1472        mapping: HashMap<NameID, TypeVarID>,
1473    ) -> GenericListToken {
1474        let reference = match source {
1475            GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1476            GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1477            GenericListSource::ImplBlock { target, id } => {
1478                GenericListToken::ImplBlock(target.clone(), id)
1479            }
1480            GenericListSource::Expression(id) => GenericListToken::Expression(id),
1481        };
1482
1483        if self
1484            .generic_lists
1485            .insert(reference.clone(), mapping)
1486            .is_some()
1487        {
1488            panic!("A generic list already existed for {reference:?}");
1489        }
1490        reference
1491    }
1492
1493    pub fn remove_generic_list(&mut self, source: GenericListSource) {
1494        let reference = match source {
1495            GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1496            GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1497            GenericListSource::ImplBlock { target, id } => {
1498                GenericListToken::ImplBlock(target.clone(), id)
1499            }
1500            GenericListSource::Expression(id) => GenericListToken::Expression(id),
1501        };
1502
1503        self.generic_lists.remove(&reference.clone());
1504    }
1505
1506    #[tracing::instrument(level = "trace", skip_all)]
1507    #[trace_typechecker]
1508    pub fn visit_block(
1509        &mut self,
1510        block: &Block,
1511        ctx: &Context,
1512        generic_list: &GenericListToken,
1513    ) -> Result<()> {
1514        for statement in &block.statements {
1515            self.visit_statement(statement, ctx, generic_list);
1516        }
1517        if let Some(result) = &block.result {
1518            self.visit_expression(result, ctx, generic_list);
1519        }
1520        Ok(())
1521    }
1522
1523    #[tracing::instrument(level = "trace", skip_all)]
1524    pub fn visit_impl_blocks(&mut self, item_list: &ItemList) -> TraitImplList {
1525        let mut trait_impls = TraitImplList::new();
1526        for (target, impls) in &item_list.impls.inner {
1527            for (trait_name, impls) in impls {
1528                for (target_args, impls) in impls {
1529                    for (trait_args, impl_block) in impls {
1530                        let result =
1531                            (|| {
1532                                let generic_list = self.create_generic_list(
1533                                    GenericListSource::ImplBlock {
1534                                        target,
1535                                        id: impl_block.id,
1536                                    },
1537                                    &[],
1538                                    impl_block.type_params.as_slice(),
1539                                    None,
1540                                    &[],
1541                                )?;
1542
1543                                let loc = trait_name
1544                                    .name_loc()
1545                                    .map(|n| ().at_loc(&n))
1546                                    .unwrap_or(().at_loc(&impl_block));
1547
1548                                let trait_type_params = trait_args
1549                                    .iter()
1550                                    .map(|param| {
1551                                        Ok(TemplateTypeVarID::new(self.hir_type_expr_to_var(
1552                                            &param.clone().at_loc(&loc),
1553                                            &generic_list,
1554                                        )?))
1555                                    })
1556                                    .collect::<Result<_>>()?;
1557
1558                                let target_type_params = target_args
1559                                    .iter()
1560                                    .map(|param| {
1561                                        Ok(TemplateTypeVarID::new(self.hir_type_expr_to_var(
1562                                            &param.clone().at_loc(&loc),
1563                                            &generic_list,
1564                                        )?))
1565                                    })
1566                                    .collect::<Result<_>>()?;
1567
1568                                trait_impls.inner.entry(target.clone()).or_default().push(
1569                                    TraitImpl {
1570                                        name: trait_name.clone(),
1571                                        target_type_params,
1572                                        trait_type_params,
1573
1574                                        impl_block: impl_block.inner.clone(),
1575                                    },
1576                                );
1577
1578                                Ok(())
1579                            })();
1580
1581                        match result {
1582                            Ok(()) => {}
1583                            Err(e) => self.diags.errors.push(e),
1584                        }
1585                    }
1586                }
1587            }
1588        }
1589
1590        trait_impls
1591    }
1592
1593    #[trace_typechecker]
1594    pub fn visit_pattern(
1595        &mut self,
1596        pattern: &Loc<Pattern>,
1597        ctx: &Context,
1598        generic_list: &GenericListToken,
1599    ) -> Result<()> {
1600        let new_type = self.new_generic_type(pattern.loc());
1601        self.add_equation(TypedExpression::Id(pattern.inner.id), new_type);
1602        match &pattern.inner.kind {
1603            hir::PatternKind::Integer(val) => {
1604                let (num_t, _) = &self.new_generic_number(pattern.loc(), ctx);
1605                self.add_requirement(Requirement::FitsIntLiteral {
1606                    value: ConstantInt::Literal(val.clone()),
1607                    target_type: num_t.clone().at_loc(pattern),
1608                });
1609                self.unify(pattern, num_t, ctx)
1610                    .expect("Failed to unify new_generic with int");
1611            }
1612            hir::PatternKind::Bool(_) => {
1613                pattern
1614                    .unify_with(&self.t_bool(pattern.loc(), ctx.symtab), self)
1615                    .commit(self, ctx)
1616                    .expect("Expected new_generic with boolean");
1617            }
1618            hir::PatternKind::Name { name, pre_declared } => {
1619                if !pre_declared {
1620                    self.add_equation(
1621                        TypedExpression::Name(name.clone().inner),
1622                        pattern.get_type(self),
1623                    );
1624                }
1625                self.unify(
1626                    &TypedExpression::Id(pattern.id),
1627                    &TypedExpression::Name(name.clone().inner),
1628                    ctx,
1629                )
1630                .into_default_diagnostic(name.loc(), self)?;
1631            }
1632            hir::PatternKind::Tuple(subpatterns) => {
1633                for pattern in subpatterns {
1634                    self.visit_pattern(pattern, ctx, generic_list)?;
1635                }
1636                let tuple_type = self.add_type_var(TypeVar::tuple(
1637                    pattern.loc(),
1638                    subpatterns
1639                        .iter()
1640                        .map(|pattern| {
1641                            let p_type = pattern.get_type(self);
1642                            Ok(p_type)
1643                        })
1644                        .collect::<Result<_>>()?,
1645                ));
1646
1647                self.unify(pattern, &tuple_type, ctx)
1648                    .expect("Unification of new_generic with tuple type cannot fail");
1649            }
1650            hir::PatternKind::Array(inner) => {
1651                for pattern in inner {
1652                    self.visit_pattern(pattern, ctx, generic_list)?;
1653                }
1654                if inner.len() == 0 {
1655                    return Err(
1656                        Diagnostic::error(pattern, "Empty array patterns are unsupported")
1657                            .primary_label("Empty array pattern"),
1658                    );
1659                } else {
1660                    let inner_t = inner[0].get_type(self);
1661
1662                    for pattern in inner.iter().skip(1) {
1663                        self.unify(pattern, &inner_t, ctx)
1664                            .into_default_diagnostic(pattern, self)?;
1665                    }
1666
1667                    pattern
1668                        .unify_with(
1669                            &TypeVar::Known(
1670                                pattern.loc(),
1671                                KnownType::Array,
1672                                vec![
1673                                    inner_t,
1674                                    self.add_type_var(TypeVar::Known(
1675                                        pattern.loc(),
1676                                        KnownType::Integer(inner.len().to_bigint()),
1677                                        vec![],
1678                                    )),
1679                                ],
1680                            )
1681                            .insert(self),
1682                            self,
1683                        )
1684                        .commit(self, ctx)
1685                        .into_default_diagnostic(pattern, self)?;
1686                }
1687            }
1688            hir::PatternKind::Type(name, args) => {
1689                let (condition_type, params, generic_list) =
1690                    match ctx.symtab.patternable_type_by_id(name).inner {
1691                        Patternable {
1692                            kind: PatternableKind::Enum,
1693                            params: _,
1694                        } => {
1695                            let enum_variant = ctx.symtab.enum_variant_by_id(name).inner;
1696                            let generic_list = self.create_generic_list(
1697                                GenericListSource::Anonymous,
1698                                &enum_variant.type_params,
1699                                &[],
1700                                None,
1701                                &[],
1702                            )?;
1703
1704                            let condition_type = self.type_var_from_hir(
1705                                pattern.loc(),
1706                                &enum_variant.output_type,
1707                                &generic_list,
1708                            )?;
1709
1710                            (condition_type, enum_variant.params, generic_list)
1711                        }
1712                        Patternable {
1713                            kind: PatternableKind::Struct,
1714                            params: _,
1715                        } => {
1716                            let s = ctx.symtab.struct_by_id(name).inner;
1717                            let generic_list = self.create_generic_list(
1718                                GenericListSource::Anonymous,
1719                                &s.type_params,
1720                                &[],
1721                                None,
1722                                &[],
1723                            )?;
1724
1725                            let condition_type =
1726                                self.type_var_from_hir(pattern.loc(), &s.self_type, &generic_list)?;
1727
1728                            (condition_type, s.params, generic_list)
1729                        }
1730                    };
1731
1732                self.unify(pattern, &condition_type, ctx)
1733                    .expect("Unification of new_generic with enum cannot fail");
1734
1735                for (
1736                    PatternArgument {
1737                        target,
1738                        value: pattern,
1739                        kind,
1740                    },
1741                    Parameter {
1742                        name: _,
1743                        ty: target_type,
1744                        no_mangle: _,
1745                        field_translator: _,
1746                    },
1747                ) in args.iter().zip(params.0.iter())
1748                {
1749                    self.visit_pattern(pattern, ctx, &generic_list)?;
1750                    let target_type =
1751                        self.type_var_from_hir(target_type.loc(), target_type, &generic_list)?;
1752
1753                    let loc = match kind {
1754                        hir::ArgumentKind::Positional => pattern.loc(),
1755                        hir::ArgumentKind::Named => pattern.loc(),
1756                        hir::ArgumentKind::ShortNamed => target.loc(),
1757                    };
1758
1759                    self.unify(pattern, &target_type, ctx).into_diagnostic(
1760                        loc,
1761                        |d, tm| {
1762                            let (expected, got) = tm.display_e_g(self);
1763                            d.message(format!(
1764                                "Argument type mismatch. Expected {expected} got {got}"
1765                            ))
1766                            .primary_label(format!("expected {expected}"))
1767                        },
1768                        self,
1769                    )?;
1770                }
1771            }
1772        }
1773        Ok(())
1774    }
1775
1776    #[trace_typechecker]
1777    pub fn visit_wal_trace(
1778        &mut self,
1779        trace: &Loc<WalTrace>,
1780        ctx: &Context,
1781        generic_list: &GenericListToken,
1782    ) -> Result<()> {
1783        let WalTrace { clk, rst } = &trace.inner;
1784        clk.as_ref()
1785            .map(|x| {
1786                self.visit_expression(x, ctx, generic_list);
1787                x.unify_with(&self.t_clock(trace.loc(), ctx.symtab), self)
1788                    .commit(self, ctx)
1789                    .into_default_diagnostic(x, self)
1790            })
1791            .transpose()?;
1792        rst.as_ref()
1793            .map(|x| {
1794                self.visit_expression(x, ctx, generic_list);
1795                x.unify_with(&self.t_bool(trace.loc(), ctx.symtab), self)
1796                    .commit(self, ctx)
1797                    .into_default_diagnostic(x, self)
1798            })
1799            .transpose()?;
1800        Ok(())
1801    }
1802
1803    #[trace_typechecker]
1804    pub fn visit_statement_error(
1805        &mut self,
1806        stmt: &Loc<Statement>,
1807        ctx: &Context,
1808        generic_list: &GenericListToken,
1809    ) -> Result<()> {
1810        match &stmt.inner {
1811            Statement::Error => {
1812                if let Some(current_stage_depth) =
1813                    self.pipeline_state.as_ref().map(|s| s.current_stage_depth)
1814                {
1815                    current_stage_depth
1816                        .unify_with(&self.t_err(stmt.loc()), self)
1817                        .commit(self, ctx)
1818                        .unwrap();
1819                }
1820                Ok(())
1821            }
1822            Statement::Binding(Binding {
1823                pattern,
1824                ty,
1825                value,
1826                wal_trace,
1827            }) => {
1828                trace!("Visiting `let {} = ..`", pattern.kind);
1829                self.visit_expression(value, ctx, generic_list);
1830
1831                self.visit_pattern(pattern, ctx, generic_list)
1832                    .handle_in(&mut self.diags);
1833
1834                self.unify(&TypedExpression::Id(pattern.id), value, ctx)
1835                    .into_diagnostic(
1836                        pattern.loc(),
1837                        error_pattern_type_mismatch(
1838                            ty.as_ref().map(|t| t.loc()).unwrap_or_else(|| value.loc()),
1839                            self,
1840                        ),
1841                        self,
1842                    )
1843                    .handle_in(&mut self.diags);
1844
1845                if let Some(t) = ty {
1846                    let tvar = self.type_var_from_hir(t.loc(), t, generic_list)?;
1847                    self.unify(&TypedExpression::Id(pattern.id), &tvar, ctx)
1848                        .into_default_diagnostic(value.loc(), self)
1849                        .handle_in(&mut self.diags);
1850                }
1851
1852                wal_trace
1853                    .as_ref()
1854                    .map(|wt| self.visit_wal_trace(wt, ctx, generic_list))
1855                    .transpose()
1856                    .handle_in(&mut self.diags);
1857
1858                Ok(())
1859            }
1860            Statement::Expression(expr) => {
1861                self.visit_expression(expr, ctx, generic_list);
1862                Ok(())
1863            }
1864            Statement::Register(reg) => self.visit_register(reg, ctx, generic_list),
1865            Statement::Declaration(names) => {
1866                for name in names {
1867                    let new_type = self.new_generic_type(name.loc());
1868                    self.add_equation(TypedExpression::Name(name.clone().inner), new_type);
1869                }
1870                Ok(())
1871            }
1872            Statement::PipelineRegMarker(extra) => {
1873                match extra {
1874                    Some(PipelineRegMarkerExtra::Condition(cond)) => {
1875                        self.visit_expression(cond, ctx, generic_list);
1876                        cond.unify_with(&self.t_bool(cond.loc(), ctx.symtab), self)
1877                            .commit(self, ctx)
1878                            .into_default_diagnostic(cond, self)?;
1879                    }
1880                    Some(PipelineRegMarkerExtra::Count {
1881                        count: _,
1882                        count_typeexpr_id: _,
1883                    }) => {}
1884                    None => {}
1885                }
1886
1887                let current_stage_depth = self
1888                    .pipeline_state
1889                    .clone()
1890                    .ok_or_else(|| {
1891                        diag_anyhow!(stmt, "Found a pipeline reg marker in a non-pipeline")
1892                    })?
1893                    .current_stage_depth;
1894
1895                let new_depth = self.new_generic_tlint(stmt.loc());
1896                let offset = match extra {
1897                    Some(PipelineRegMarkerExtra::Count {
1898                        count,
1899                        count_typeexpr_id,
1900                    }) => {
1901                        let var = self.hir_type_expr_to_var(count, generic_list)?;
1902                        self.add_equation(TypedExpression::Id(*count_typeexpr_id), var.clone());
1903                        var
1904                    }
1905                    Some(PipelineRegMarkerExtra::Condition(_)) | None => self.add_type_var(
1906                        TypeVar::Known(stmt.loc(), KnownType::Integer(1.to_bigint()), vec![]),
1907                    ),
1908                };
1909
1910                let total_depth = ConstraintExpr::Sum(
1911                    Box::new(ConstraintExpr::Var(offset)),
1912                    Box::new(ConstraintExpr::Var(current_stage_depth)),
1913                );
1914                self.pipeline_state
1915                    .as_mut()
1916                    .expect("Expected to have a pipeline state")
1917                    .current_stage_depth = new_depth.clone();
1918
1919                self.add_constraint(
1920                    new_depth.clone(),
1921                    total_depth,
1922                    stmt.loc(),
1923                    &new_depth,
1924                    ConstraintSource::PipelineRegCount {
1925                        reg: stmt.loc(),
1926                        total: self.get_pipeline_state(stmt)?.total_depth.loc(),
1927                    },
1928                );
1929
1930                Ok(())
1931            }
1932            Statement::Label(name) => {
1933                let key = TypedExpression::Name(name.inner.clone());
1934                let var = if !self.equations.contains_key(&key) {
1935                    let var = self.new_generic_tlint(name.loc());
1936                    self.trace_stack.push(TraceStackEntry::AddingPipelineLabel(
1937                        name.inner.clone(),
1938                        var.debug_resolve(self),
1939                    ));
1940                    self.add_equation(key.clone(), var.clone());
1941                    var
1942                } else {
1943                    let var = self.equations.get(&key).unwrap().clone();
1944                    self.trace_stack
1945                        .push(TraceStackEntry::RecoveringPipelineLabel(
1946                            name.inner.clone(),
1947                            var.debug_resolve(self),
1948                        ));
1949                    var
1950                };
1951                // Safe unwrap, unifying with a fresh var
1952                self.unify(
1953                    &var,
1954                    &self.get_pipeline_state(name)?.current_stage_depth.clone(),
1955                    ctx,
1956                )
1957                .unwrap();
1958                Ok(())
1959            }
1960            Statement::WalSuffixed { .. } => Ok(()),
1961            Statement::Assert(expr) => {
1962                self.visit_expression(expr, ctx, generic_list);
1963
1964                expr.unify_with(&self.t_bool(stmt.loc(), ctx.symtab), self)
1965                    .commit(self, ctx)
1966                    .into_default_diagnostic(expr, self)
1967                    .handle_in(&mut self.diags);
1968                Ok(())
1969            }
1970            Statement::Set { target, value } => {
1971                self.visit_expression(target, ctx, generic_list);
1972                self.visit_expression(value, ctx, generic_list);
1973
1974                let inner_type = self.new_generic_type(value.loc());
1975                let outer_type = TypeVar::inverted(stmt.loc(), inner_type.clone()).insert(self);
1976                self.unify_expression_generic_error(target, &outer_type, ctx)
1977                    .handle_in(&mut self.diags);
1978                self.unify_expression_generic_error(value, &inner_type, ctx)
1979                    .handle_in(&mut self.diags);
1980
1981                Ok(())
1982            }
1983        }
1984    }
1985
1986    pub fn visit_statement(
1987        &mut self,
1988        stmt: &Loc<Statement>,
1989        ctx: &Context,
1990        generic_list: &GenericListToken,
1991    ) {
1992        if let Err(e) = self.visit_statement_error(stmt, ctx, generic_list) {
1993            self.diags.errors.push(e);
1994        }
1995    }
1996
1997    #[trace_typechecker]
1998    pub fn visit_register(
1999        &mut self,
2000        reg: &Register,
2001        ctx: &Context,
2002        generic_list: &GenericListToken,
2003    ) -> Result<()> {
2004        self.visit_pattern(&reg.pattern, ctx, generic_list)?;
2005
2006        let type_spec_type = match &reg.value_type {
2007            Some(t) => Some(self.type_var_from_hir(t.loc(), t, generic_list)?.at_loc(t)),
2008            None => None,
2009        };
2010
2011        // We need to do this before visiting value, in case it constrains the
2012        // type of the identifiers in the pattern
2013        if let Some(tvar) = &type_spec_type {
2014            self.unify(&TypedExpression::Id(reg.pattern.id), tvar, ctx)
2015                .into_diagnostic_no_expected_source(
2016                    reg.pattern.loc(),
2017                    error_pattern_type_mismatch(tvar.loc(), self),
2018                    self,
2019                )?;
2020        }
2021
2022        self.visit_expression(&reg.clock, ctx, generic_list);
2023        self.visit_expression(&reg.value, ctx, generic_list);
2024
2025        if let Some(tvar) = &type_spec_type {
2026            self.unify(&reg.value, tvar, ctx)
2027                .into_default_diagnostic(reg.value.loc(), self)?;
2028        }
2029
2030        if let Some((rst_cond, rst_value)) = &reg.reset {
2031            self.visit_expression(rst_cond, ctx, generic_list);
2032            self.visit_expression(rst_value, ctx, generic_list);
2033            // Ensure cond is a boolean
2034            rst_cond
2035                .unify_with(&self.t_bool(rst_cond.loc(), ctx.symtab), self)
2036                .commit(self, ctx)
2037                .into_diagnostic(
2038                    rst_cond.loc(),
2039                    |diag,
2040                     Tm {
2041                         g: got,
2042                         e: _expected,
2043                     }| {
2044                        diag.message(format!(
2045                            "Register reset condition must be a bool, got {got}",
2046                            got = got.display(self)
2047                        ))
2048                        .primary_label("expected bool")
2049                    },
2050                    self,
2051                )?;
2052
2053            // Ensure the reset value has the same type as the register itself
2054            self.unify(&rst_value.inner, &reg.value.inner, ctx)
2055                .into_diagnostic(
2056                    rst_value.loc(),
2057                    |diag, tm| {
2058                        let (expected, got) = tm.display_e_g(self);
2059                        diag.message(format!(
2060                            "Register reset value mismatch. Expected {expected} got {got}"
2061                        ))
2062                        .primary_label(format!("expected {expected}"))
2063                        .secondary_label(&reg.pattern, format!("because this has type {expected}"))
2064                    },
2065                    self,
2066                )?;
2067        }
2068
2069        if let Some(initial) = &reg.initial {
2070            self.visit_expression(initial, ctx, generic_list);
2071
2072            self.unify(&initial.inner, &reg.value.inner, ctx)
2073                .into_diagnostic(
2074                    initial.loc(),
2075                    |diag, tm| {
2076                        let (expected, got) = tm.display_e_g(self);
2077                        diag.message(format!(
2078                            "Register initial value mismatch. Expected {expected} got {got}"
2079                        ))
2080                        .primary_label(format!("expected {expected}, got {got}"))
2081                        .secondary_label(&reg.pattern, format!("because this has type {got}"))
2082                    },
2083                    self,
2084                )?;
2085        }
2086
2087        reg.clock
2088            .unify_with(&self.t_clock(reg.clock.loc(), ctx.symtab), self)
2089            .commit(self, ctx)
2090            .into_diagnostic(
2091                reg.clock.loc(),
2092                |diag,
2093                 Tm {
2094                     g: got,
2095                     e: _expected,
2096                 }| {
2097                    diag.message(format!(
2098                        "Expected clock, got {got}",
2099                        got = got.display(self)
2100                    ))
2101                    .primary_label("expected clock")
2102                },
2103                self,
2104            )?;
2105
2106        self.unify(&TypedExpression::Id(reg.pattern.id), &reg.value, ctx)
2107            .into_diagnostic(
2108                reg.pattern.loc(),
2109                error_pattern_type_mismatch(reg.value.loc(), self),
2110                self,
2111            )?;
2112
2113        Ok(())
2114    }
2115
2116    #[trace_typechecker]
2117    pub fn visit_trait_spec(
2118        &mut self,
2119        trait_spec: &Loc<TraitSpec>,
2120        generic_list: &GenericListToken,
2121    ) -> Result<Loc<TraitReq>> {
2122        let type_params = if let Some(type_params) = &trait_spec.inner.type_params {
2123            type_params
2124                .inner
2125                .iter()
2126                .map(|te| self.hir_type_expr_to_var(te, generic_list))
2127                .collect::<Result<_>>()?
2128        } else {
2129            vec![]
2130        };
2131
2132        Ok(TraitReq {
2133            name: trait_spec.name.clone(),
2134            type_params,
2135        }
2136        .at_loc(trait_spec))
2137    }
2138
2139    #[trace_typechecker]
2140    pub fn visit_trait_bounds(
2141        &mut self,
2142        target: &Loc<NameID>,
2143        traits: &[Loc<TraitSpec>],
2144        generic_list_tok: &GenericListToken,
2145    ) -> Result<()> {
2146        let trait_reqs = traits
2147            .iter()
2148            .map(|spec| self.visit_trait_spec(spec, generic_list_tok))
2149            .collect::<Result<BTreeSet<_>>>()?
2150            .into_iter()
2151            .collect_vec();
2152
2153        if !trait_reqs.is_empty() {
2154            let trait_list = TraitList::from_vec(trait_reqs);
2155
2156            let generic_list = self.generic_lists.get(generic_list_tok).unwrap();
2157
2158            let Some(tvar) = generic_list.get(&target.inner) else {
2159                return Err(Diagnostic::bug(
2160                    target,
2161                    "Couldn't find generic from where clause in generic list",
2162                )
2163                .primary_label(format!(
2164                    "Generic type {} not found in generic list",
2165                    target.inner
2166                )));
2167            };
2168
2169            self.trace_stack.push(TraceStackEntry::AddingTraitBounds(
2170                tvar.debug_resolve(self),
2171                trait_list.clone(),
2172            ));
2173
2174            match tvar.resolve(self) {
2175                TypeVar::Known(_, _, _) => {
2176                    // NOTE: This branch is a no-op as it is only triggered when re-visiting
2177                    // units for typeinference during monomorpization, a process which replaces
2178                    // some unknown types with known counterparts.
2179                    // Ideally, we'd re-run ensure_impls here but that requires a context which
2180                    // isn't necessarily available here, so #yolo
2181                }
2182                TypeVar::Unknown(loc, id, old_trait_list, _meta_type) => {
2183                    let new_tvar = self.add_type_var(TypeVar::Unknown(
2184                        *loc,
2185                        *id,
2186                        old_trait_list.clone().extend(trait_list),
2187                        MetaType::Type,
2188                    ));
2189
2190                    trace!(
2191                        "Adding trait bound {} on type {}",
2192                        new_tvar.display_with_meta(true, self),
2193                        target.inner
2194                    );
2195
2196                    let generic_list = self.generic_lists.get_mut(generic_list_tok).unwrap();
2197                    generic_list.insert(target.inner.clone(), new_tvar);
2198                }
2199            }
2200        }
2201
2202        Ok(())
2203    }
2204
2205    pub fn visit_const_generic_with_id(
2206        &mut self,
2207        gen: &Loc<ConstGenericWithId>,
2208        generic_list_token: &GenericListToken,
2209        constraint_source: ConstraintSource,
2210        ctx: &Context,
2211    ) -> Result<TypeVarID> {
2212        let var = match &gen.inner.inner {
2213            ConstGeneric::Name(name) => {
2214                let ty = &ctx.symtab.type_symbol_by_id(&name);
2215                match &ty.inner {
2216                    TypeSymbol::Declared(_, _) => {
2217                        return Err(Diagnostic::error(
2218                            name,
2219                            "{type_decl_kind} cannot be used in a const generic expression",
2220                        )
2221                        .primary_label("Type in  const generic")
2222                        .secondary_label(ty, "{name} is declared here"))
2223                    }
2224                    TypeSymbol::GenericArg { .. } | TypeSymbol::GenericMeta(MetaType::Type) => {
2225                        return Err(Diagnostic::error(
2226                            name,
2227                            "Generic types cannot be used in const generic expressions",
2228                        )
2229                        .primary_label("Type in  const generic")
2230                        .secondary_label(ty, "{name} is declared here")
2231                        .span_suggest_insert_before(
2232                            "Consider making this a value",
2233                            ty.loc(),
2234                            "#uint ",
2235                        ))
2236                    }
2237                    TypeSymbol::GenericMeta(MetaType::Number) => {
2238                        self.new_generic_tlnumber(gen.loc())
2239                    }
2240                    TypeSymbol::GenericMeta(MetaType::Int) => self.new_generic_tlint(gen.loc()),
2241                    TypeSymbol::GenericMeta(MetaType::Uint) => self.new_generic_tluint(gen.loc()),
2242                    TypeSymbol::GenericMeta(MetaType::Bool) => self.new_generic_tlbool(gen.loc()),
2243                    TypeSymbol::GenericMeta(MetaType::Str) => self.new_generic_tlstr(gen.loc()),
2244                    TypeSymbol::GenericMeta(MetaType::Any) => {
2245                        diag_bail!(gen, "Found any meta type")
2246                    }
2247                    TypeSymbol::Alias(_) => {
2248                        return Err(Diagnostic::error(
2249                            gen,
2250                            "Aliases are not currently supported in const generics",
2251                        )
2252                        .secondary_label(ty, "Alias defined here"))
2253                    }
2254                }
2255            }
2256            ConstGeneric::Int(_)
2257            | ConstGeneric::Add(_, _)
2258            | ConstGeneric::Sub(_, _)
2259            | ConstGeneric::Mul(_, _)
2260            | ConstGeneric::Div(_, _)
2261            | ConstGeneric::Mod(_, _)
2262            | ConstGeneric::UintBitsToFit(_) => self.new_generic_tlnumber(gen.loc()),
2263            ConstGeneric::Str(_) => self.new_generic_tlstr(gen.loc()),
2264            ConstGeneric::Eq(_, _) | ConstGeneric::NotEq(_, _) => {
2265                self.new_generic_tlbool(gen.loc())
2266            }
2267        };
2268        let constraint = self.visit_const_generic(&gen.inner.inner, generic_list_token)?;
2269        self.add_equation(TypedExpression::Id(gen.id), var.clone());
2270        self.add_constraint(var.clone(), constraint, gen.loc(), &var, constraint_source);
2271        Ok(var)
2272    }
2273
2274    #[trace_typechecker]
2275    pub fn visit_const_generic(
2276        &self,
2277        constraint: &ConstGeneric,
2278        generic_list: &GenericListToken,
2279    ) -> Result<ConstraintExpr> {
2280        let wrap = |lhs,
2281                    rhs,
2282                    wrapper: fn(Box<ConstraintExpr>, Box<ConstraintExpr>) -> ConstraintExpr|
2283         -> Result<_> {
2284            Ok(wrapper(
2285                Box::new(self.visit_const_generic(lhs, generic_list)?),
2286                Box::new(self.visit_const_generic(rhs, generic_list)?),
2287            ))
2288        };
2289        let constraint = match constraint {
2290            ConstGeneric::Name(n) => {
2291                let var = self
2292                    .get_generic_list(generic_list)
2293                    .ok_or_else(|| diag_anyhow!(n, "Found no generic list"))?
2294                    .get(n)
2295                    .ok_or_else(|| {
2296                        Diagnostic::bug(n, "Found non-generic argument in where clause")
2297                    })?;
2298                ConstraintExpr::Var(*var)
2299            }
2300            ConstGeneric::Int(val) => ConstraintExpr::Integer(val.clone()),
2301            ConstGeneric::Str(val) => ConstraintExpr::String(val.clone()),
2302            ConstGeneric::Add(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Sum)?,
2303            ConstGeneric::Sub(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Difference)?,
2304            ConstGeneric::Mul(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Product)?,
2305            ConstGeneric::Div(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Div)?,
2306            ConstGeneric::Mod(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Mod)?,
2307            ConstGeneric::Eq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Eq)?,
2308            ConstGeneric::NotEq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::NotEq)?,
2309            ConstGeneric::UintBitsToFit(a) => ConstraintExpr::UintBitsToRepresent(Box::new(
2310                self.visit_const_generic(a, generic_list)?,
2311            )),
2312        };
2313        Ok(constraint)
2314    }
2315}
2316
// Helper functions (not all private: e.g. `add_equation`, `unify` and `can_unify` are `pub`)
2318impl TypeState {
2319    fn new_typeid(&self) -> u64 {
2320        let mut next = self.next_typeid.borrow_mut();
2321        let result = *next;
2322        *next += 1;
2323        result
2324    }
2325
2326    pub fn add_equation(&mut self, expression: TypedExpression, var: TypeVarID) {
2327        self.trace_stack.push(TraceStackEntry::AddingEquation(
2328            expression.clone(),
2329            var.debug_resolve(self),
2330        ));
2331        if let Some(prev) = self.equations.insert(expression.clone(), var.clone()) {
2332            let var = var.clone();
2333            let expr = expression.clone();
2334            println!("{}", format_trace_stack(self));
2335            panic!("Adding equation for {} == {} but a previous eq exists.\n\tIt was previously bound to {}", expr, var.debug_resolve(self), prev.debug_resolve(self))
2336        }
2337    }
2338
2339    fn add_constraint(
2340        &mut self,
2341        lhs: TypeVarID,
2342        rhs: ConstraintExpr,
2343        loc: Loc<()>,
2344        inside: &TypeVarID,
2345        source: ConstraintSource,
2346    ) {
2347        let replaces = lhs.clone();
2348        let rhs = rhs.with_context(&replaces, &inside, source).at_loc(&loc);
2349
2350        self.trace_stack.push(TraceStackEntry::AddingConstraint(
2351            lhs.debug_resolve(self),
2352            rhs.inner.clone(),
2353        ));
2354
2355        self.constraints.add_int_constraint(lhs, rhs);
2356    }
2357
2358    fn add_requirement(&mut self, requirement: Requirement) {
2359        self.trace_stack
2360            .push(TraceStackEntry::AddRequirement(requirement.clone()));
2361        self.requirements.push(requirement)
2362    }
2363
2364    /// Performs unification but does not update constraints. This is done to avoid
2365    /// updating constraints more often than necessary. Technically, constraints can
2366    /// be updated even less often, but `unify` is a pretty natural point to do so.
2367
2368    fn unify_inner(
2369        &mut self,
2370        e1: &impl HasType,
2371        e2: &impl HasType,
2372        ctx: &Context,
2373    ) -> std::result::Result<TypeVarID, UnificationError> {
2374        let v1 = e1.get_type(self);
2375        let v2 = e2.get_type(self);
2376
2377        trace!(
2378            "Unifying {} with {}",
2379            v1.debug_resolve(self),
2380            v2.debug_resolve(self)
2381        );
2382
2383        self.trace_stack.push(TraceStackEntry::TryingUnify(
2384            v1.debug_resolve(self),
2385            v2.debug_resolve(self),
2386        ));
2387
2388        macro_rules! err_producer {
2389            () => {{
2390                self.trace_stack
2391                    .push(TraceStackEntry::Message("Produced error".to_string()));
2392                UnificationError::Normal(Tm {
2393                    g: UnificationTrace::new(v1),
2394                    e: UnificationTrace::new(v2),
2395                })
2396            }};
2397        }
2398        macro_rules! meta_err_producer {
2399            () => {{
2400                self.trace_stack
2401                    .push(TraceStackEntry::Message("Produced error".to_string()));
2402                UnificationError::MetaMismatch(Tm {
2403                    g: UnificationTrace::new(v1),
2404                    e: UnificationTrace::new(v2),
2405                })
2406            }};
2407        }
2408
2409        macro_rules! unify_if {
2410            ($condition:expr, $new_type:expr, $replaced_type:expr) => {
2411                if $condition {
2412                    Ok(($new_type, $replaced_type))
2413                } else {
2414                    Err(err_producer!())
2415                }
2416            };
2417        }
2418
2419        let unify_params = |s: &mut Self,
2420                            p1: &[TypeVarID],
2421                            p2: &[TypeVarID]|
2422         -> std::result::Result<(), UnificationError> {
2423            if p1.len() != p2.len() {
2424                return Err({
2425                    s.trace_stack
2426                        .push(TraceStackEntry::Message("Produced error".to_string()));
2427                    UnificationError::Normal(Tm {
2428                        g: UnificationTrace::new(v1),
2429                        e: UnificationTrace::new(v2),
2430                    })
2431                });
2432            }
2433
2434            for (t1, t2) in p1.iter().zip(p2.iter()) {
2435                match s.unify_inner(t1, t2, ctx) {
2436                    Ok(result) => result,
2437                    Err(e) => {
2438                        s.trace_stack
2439                            .push(TraceStackEntry::Message("Adding context".to_string()));
2440                        return Err(e).add_context(v1.clone(), v2.clone());
2441                    }
2442                };
2443            }
2444            Ok(())
2445        };
2446
2447        // Figure out the most general type, and take note if we need to
2448        // do any replacement of the types in the rest of the state
2449        let result = match (
2450            &(v1, v1.resolve(self).clone()),
2451            &(v2, v2.resolve(self).clone()),
2452        ) {
2453            ((_, TypeVar::Known(_, KnownType::Error, _)), _) => Ok((v1, vec![v2])),
2454            (_, (_, TypeVar::Known(_, KnownType::Error, _))) => Ok((v2, vec![v1])),
2455            ((_, TypeVar::Known(_, t1, p1)), (_, TypeVar::Known(_, t2, p2))) => {
2456                match (t1, t2) {
2457                    (KnownType::Integer(val1), KnownType::Integer(val2)) => {
2458                        // NOTE: We get better error messages if we don't actually
2459                        // replace v1 with v2. Not entirely sure why, but since they already
2460                        // have the same known type, we're fine. The same applies
2461                        // to all known types
2462                        unify_if!(val1 == val2, v1, vec![])
2463                    }
2464                    (KnownType::String(val1), KnownType::String(val2)) => {
2465                        // Copied from the (Integer, Integer) case, its remark may also apply
2466                        unify_if!(val1 == val2, v1, vec![])
2467                    }
2468                    (KnownType::Named(n1), KnownType::Named(n2)) => {
2469                        match (
2470                            &ctx.symtab.type_symbol_by_id(n1).inner,
2471                            &ctx.symtab.type_symbol_by_id(n2).inner,
2472                        ) {
2473                            (TypeSymbol::Declared(_, _), TypeSymbol::Declared(_, _)) => {
2474                                if n1 != n2 {
2475                                    return Err(err_producer!());
2476                                }
2477
2478                                let new_ts1 = ctx.symtab.type_symbol_by_id(n1).inner;
2479                                let new_ts2 = ctx.symtab.type_symbol_by_id(n2).inner;
2480                                unify_params(self, &p1, &p2)?;
2481                                unify_if!(new_ts1 == new_ts2, v1, vec![])
2482                            }
2483                            (TypeSymbol::Declared(_, _), TypeSymbol::GenericArg { traits }) => {
2484                                if !traits.is_empty() {
2485                                    todo!("Implement trait unifictaion");
2486                                }
2487                                Ok((v1, vec![]))
2488                            }
2489                            (TypeSymbol::GenericArg { traits }, TypeSymbol::Declared(_, _)) => {
2490                                if !traits.is_empty() {
2491                                    todo!("Implement trait unifictaion");
2492                                }
2493                                Ok((v2, vec![]))
2494                            }
2495                            (
2496                                TypeSymbol::GenericArg { traits: ltraits },
2497                                TypeSymbol::GenericArg { traits: rtraits },
2498                            ) => {
2499                                if !ltraits.is_empty() || !rtraits.is_empty() {
2500                                    todo!("Implement trait unifictaion");
2501                                }
2502                                Ok((v1, vec![]))
2503                            }
2504                            (TypeSymbol::Declared(_, _), TypeSymbol::GenericMeta(_)) => todo!(),
2505                            (TypeSymbol::GenericArg { traits: _ }, TypeSymbol::GenericMeta(_)) => {
2506                                todo!()
2507                            }
2508                            (TypeSymbol::GenericMeta(_), TypeSymbol::Declared(_, _)) => todo!(),
2509                            (TypeSymbol::GenericMeta(_), TypeSymbol::GenericArg { traits: _ }) => {
2510                                todo!()
2511                            }
2512                            (TypeSymbol::Alias(_), _) | (_, TypeSymbol::Alias(_)) => {
2513                                return Err(UnificationError::Specific(Diagnostic::bug(
2514                                    ().nowhere(),
2515                                    "Encountered a raw type alias during unification",
2516                                )))
2517                            }
2518                            (TypeSymbol::GenericMeta(_), TypeSymbol::GenericMeta(_)) => todo!(),
2519                        }
2520                    }
2521                    (KnownType::Array, KnownType::Array)
2522                    | (KnownType::Tuple, KnownType::Tuple)
2523                    | (KnownType::Wire, KnownType::Wire)
2524                    | (KnownType::Inverted, KnownType::Inverted) => {
2525                        // Note, replacements should only occur when a variable goes from Unknown
2526                        // to Known, not when the base is the same. Replacements take care
2527                        // of parameters. Therefore, None is returned here
2528                        unify_params(self, &p1, &p2)?;
2529                        Ok((v1, vec![]))
2530                    }
2531                    (_, _) => Err(err_producer!()),
2532                }
2533            }
2534            // Unknown with unknown requires merging traits
2535            (
2536                (_, TypeVar::Unknown(loc1, _, traits1, meta1)),
2537                (_, TypeVar::Unknown(loc2, _, traits2, meta2)),
2538            ) => {
2539                let new_loc = if meta1.is_more_concrete_than(meta2) {
2540                    loc1
2541                } else {
2542                    loc2
2543                };
2544                let new_t = match unify_meta(meta1, meta2) {
2545                    Some(meta @ MetaType::Any) | Some(meta @ MetaType::Number) => {
2546                        if traits1.inner.is_empty() || traits2.inner.is_empty() {
2547                            return Err(UnificationError::Specific(diag_anyhow!(
2548                                new_loc,
2549                                "Inferred an any meta-type with traits"
2550                            )));
2551                        }
2552                        self.new_generic_with_meta(*loc1, meta)
2553                    }
2554                    Some(MetaType::Type) => {
2555                        let new_trait_names = traits1
2556                            .inner
2557                            .iter()
2558                            .chain(traits2.inner.iter())
2559                            .map(|t| t.name.clone())
2560                            .collect::<BTreeSet<_>>()
2561                            .into_iter()
2562                            .collect::<Vec<_>>();
2563
2564                        let new_traits = new_trait_names
2565                            .iter()
2566                            .map(
2567                                |name| match (traits1.get_trait(name), traits2.get_trait(name)) {
2568                                    (Some(req1), Some(req2)) => {
2569                                        let new_params = req1
2570                                            .inner
2571                                            .type_params
2572                                            .iter()
2573                                            .zip(req2.inner.type_params.iter())
2574                                            .map(|(p1, p2)| self.unify(p1, p2, ctx))
2575                                            .collect::<std::result::Result<_, UnificationError>>(
2576                                            )?;
2577
2578                                        Ok(TraitReq {
2579                                            name: name.clone(),
2580                                            type_params: new_params,
2581                                        }
2582                                        .at_loc(req1))
2583                                    }
2584                                    (Some(t), None) => Ok(t.clone()),
2585                                    (None, Some(t)) => Ok(t.clone()),
2586                                    (None, None) => panic!("Found a trait but neither side has it"),
2587                                },
2588                            )
2589                            .collect::<std::result::Result<Vec<_>, UnificationError>>()?;
2590
2591                        self.new_generic_with_traits(*new_loc, TraitList::from_vec(new_traits))
2592                    }
2593                    Some(MetaType::Int) => self.new_generic_tlint(*new_loc),
2594                    Some(MetaType::Uint) => self.new_generic_tluint(*new_loc),
2595                    Some(MetaType::Bool) => self.new_generic_tlbool(*new_loc),
2596                    Some(MetaType::Str) => self.new_generic_tlstr(*new_loc),
2597                    None => return Err(meta_err_producer!()),
2598                };
2599                Ok((new_t, vec![v1, v2]))
2600            }
2601            (
2602                (otherid, TypeVar::Known(loc, base, params)),
2603                (ukid, TypeVar::Unknown(ukloc, _, traits, meta)),
2604            )
2605            | (
2606                (ukid, TypeVar::Unknown(ukloc, _, traits, meta)),
2607                (otherid, TypeVar::Known(loc, base, params)),
2608            ) => {
2609                let trait_is_expected = match (&v1.resolve(self), &v2.resolve(self)) {
2610                    (TypeVar::Known(_, _, _), _) => true,
2611                    _ => false,
2612                };
2613
2614                let impls = self.ensure_impls(otherid, traits, trait_is_expected, ukloc, ctx)?;
2615
2616                self.trace_stack.push(TraceStackEntry::Message(
2617                    "Unifying trait_parameters".to_string(),
2618                ));
2619                let mut new_params = params.clone();
2620                for (trait_impl, trait_req) in impls {
2621                    let mut param_map = BTreeMap::new();
2622
2623                    for (l, r) in trait_req
2624                        .type_params
2625                        .iter()
2626                        .zip(trait_impl.trait_type_params)
2627                    {
2628                        let copy = r.make_copy_with_mapping(self, &mut param_map);
2629                        self.unify(l, &copy, ctx)?;
2630                    }
2631
2632                    new_params = trait_impl
2633                        .target_type_params
2634                        .iter()
2635                        .zip(new_params)
2636                        .map(|(l, r)| {
2637                            let copy = l.make_copy_with_mapping(self, &mut param_map);
2638                            self.unify(&copy, &r, ctx).add_context(*ukid, *otherid)
2639                        })
2640                        .collect::<std::result::Result<_, _>>()?
2641                }
2642
2643                match (base, meta) {
2644                    (KnownType::Error, _) => {
2645                        unreachable!()
2646                    }
2647                    // Any matches all types
2648                    (_, MetaType::Any)
2649                    // All types are types
2650                    | (KnownType::Named(_), MetaType::Type)
2651                    | (KnownType::Tuple, MetaType::Type)
2652                    | (KnownType::Array, MetaType::Type)
2653                    | (KnownType::Wire, MetaType::Type)
2654                    | (KnownType::Bool(_), MetaType::Bool)
2655                    | (KnownType::String(_), MetaType::Str)
2656                    | (KnownType::Inverted, MetaType::Type)
2657                    // Integers match ints and numbers
2658                    | (KnownType::Integer(_), MetaType::Int)
2659                    | (KnownType::Integer(_), MetaType::Number)
2660                    => {
2661                        let new = self.add_type_var(TypeVar::Known(*loc, base.clone(), new_params));
2662
2663                        Ok((new, vec![otherid.clone(), ukid.clone()]))
2664                    },
2665                    // Integers match uints but only if the concrete integer is >= 0
2666                    (KnownType::Integer(val), MetaType::Uint)
2667                    => {
2668                        if val < &0.to_bigint() {
2669                            Err(meta_err_producer!())
2670                        } else {
2671                            let new = self.add_type_var(TypeVar::Known(*loc, base.clone(), new_params));
2672
2673                            Ok((new, vec![otherid.clone(), ukid.clone()]))
2674                        }
2675                    },
2676
2677                    // Integer with type
2678                    (KnownType::Integer(_), MetaType::Type) => Err(meta_err_producer!()),
2679
2680                    // Bools only unify with any or bool
2681                    (_, MetaType::Bool) => Err(meta_err_producer!()),
2682                    (KnownType::Bool(_), _) => Err(meta_err_producer!()),
2683
2684                    // Strings only unify with any or str
2685                    (_, MetaType::Str) => Err(meta_err_producer!()),
2686                    (KnownType::String(_), _) => Err(meta_err_producer!()),
2687
2688                    // Type with integer
2689                    (KnownType::Named(_), MetaType::Int | MetaType::Number | MetaType::Uint)
2690                    | (KnownType::Tuple, MetaType::Int | MetaType::Uint | MetaType::Number)
2691                    | (KnownType::Array, MetaType::Int | MetaType::Uint | MetaType::Number)
2692                    | (KnownType::Wire, MetaType::Int | MetaType::Uint | MetaType::Number)
2693                    | (KnownType::Inverted, MetaType::Int | MetaType::Uint | MetaType::Number)
2694                    => Err(meta_err_producer!())
2695                }
2696            }
2697        };
2698
2699        let (new_type, replaced_types) = result?;
2700
2701        self.trace_stack.push(TraceStackEntry::Unified(
2702            v1.debug_resolve(self),
2703            v2.debug_resolve(self),
2704            new_type.debug_resolve(self),
2705            replaced_types
2706                .iter()
2707                .map(|v| v.debug_resolve(self))
2708                .collect(),
2709        ));
2710
2711        for replaced_type in &replaced_types {
2712            if v1.inner != v2.inner {
2713                let (from, to) = (replaced_type.get_type(self), new_type.get_type(self));
2714                self.replacements.insert(from, to);
2715                if let Err(rec) = self.check_type_for_recursion(to, &mut vec![]) {
2716                    let err_t = self.t_err(().nowhere());
2717                    self.replacements.insert(to, err_t);
2718                    return Err(UnificationError::RecursiveType(rec));
2719                }
2720            }
2721        }
2722
2723        Ok(new_type)
2724    }
2725
2726    pub fn can_unify(&mut self, e1: &impl HasType, e2: &impl HasType, ctx: &Context) -> bool {
2727        self.trace_stack
2728            .push(TraceStackEntry::Enter("Running can_unify".to_string()));
2729        let result = self.do_and_restore(|s| s.unify(e1, e2, ctx)).is_ok();
2730        self.trace_stack.push(TraceStackEntry::Exit);
2731        result
2732    }
2733
    /// Unifies `e1` with `e2`, then propagates any information the updated type-level
    /// constraints make available.
    ///
    /// After the initial `unify_inner`, this repeatedly calls
    /// `update_type_level_value_constraints`. Each piece of new information it yields
    /// (a variable and the value it was solved to) is turned into a `TypeVar::Known`
    /// with no parameters and unified with that variable. The loop stops once an update
    /// produces no new information.
    ///
    /// # Errors
    /// An error from the initial unification is returned unchanged. Normal and meta
    /// mismatches found while applying constraint-derived information are rewrapped as
    /// `UnificationError::FromConstraints` (with `is_meta_error` set for meta
    /// mismatches); any other error kind is returned as-is.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn unify(
        &mut self,
        e1: &impl HasType,
        e2: &impl HasType,
        ctx: &Context,
    ) -> std::result::Result<TypeVarID, UnificationError> {
        let new_type = self.unify_inner(e1, e2, ctx)?;

        // With replacement done, some of our constraints may have been updated to provide
        // more type inference information. Try to do unification of those new constraints too
        loop {
            trace!("Updating constraints");
            // NOTE: Cloning here is kind of ugly
            let new_info;
            (self.constraints, new_info) = self
                .constraints
                .clone()
                .update_type_level_value_constraints(self);

            // Fixpoint reached: nothing new was inferred from the constraints
            if new_info.is_empty() {
                break;
            }

            for constraint in new_info {
                trace!(
                    "Constraint replaces {} with {:?}",
                    constraint.inner.0.display(self),
                    constraint.inner.1
                );

                let ((var, replacement), loc) = constraint.split_loc();

                self.trace_stack
                    .push(TraceStackEntry::InferringFromConstraints(
                        var.debug_resolve(self),
                        replacement.val.clone(),
                    ));

                // The constraint determined `var` to be `replacement.val`; materialize that
                // as a known type (without parameters) and unify it with `var`.
                let expected_type = self.add_type_var(TypeVar::Known(loc, replacement.val, vec![]));
                let result = self.unify_inner(&expected_type.clone().at_loc(&loc), &var, ctx);
                let is_meta_error = matches!(result, Err(UnificationError::MetaMismatch { .. }));
                match result {
                    Ok(_) => {}
                    Err(UnificationError::Normal(Tm { mut e, mut g }))
                    | Err(UnificationError::MetaMismatch(Tm { mut e, mut g })) => {
                        // Attach the constraint's `inside` context (with `var` substituted
                        // by the failing type) to both sides of the mismatch, so the error
                        // can presumably show where the constraint came from.
                        e.inside.replace(
                            replacement
                                .context
                                .inside
                                .replace_inside(var, e.failing, self),
                        );
                        g.inside.replace(
                            replacement
                                .context
                                .inside
                                .replace_inside(var, g.failing, self),
                        );
                        return Err(UnificationError::FromConstraints {
                            got: g,
                            expected: e,
                            source: replacement.context.source,
                            loc,
                            is_meta_error,
                        });
                    }
                    Err(
                        e @ UnificationError::FromConstraints { .. }
                        | e @ UnificationError::Specific { .. }
                        | e @ UnificationError::RecursiveType(_)
                        | e @ UnificationError::UnsatisfiedTraits { .. },
                    ) => return Err(e),
                };
            }
        }

        Ok(new_type)
    }
2813
2814    fn check_type_for_recursion(
2815        &self,
2816        ty: TypeVarID,
2817        seen: &mut Vec<TypeVarID>,
2818    ) -> std::result::Result<(), String> {
2819        seen.push(ty);
2820        match ty.resolve(self) {
2821            TypeVar::Known(_, base, params) => {
2822                for (i, param) in params.iter().enumerate() {
2823                    if seen.contains(param) {
2824                        return Err("*".to_string());
2825                    }
2826
2827                    if let Err(rest) = self.check_type_for_recursion(*param, seen) {
2828                        let list = params
2829                            .iter()
2830                            .enumerate()
2831                            .map(|(j, _)| {
2832                                if j == i {
2833                                    rest.clone()
2834                                } else {
2835                                    "_".to_string()
2836                                }
2837                            })
2838                            .join(", ");
2839
2840                        match base {
2841                            KnownType::Error => {}
2842                            KnownType::Named(name_id) => {
2843                                return Err(format!("{name_id}<{}>", list));
2844                            }
2845                            KnownType::Bool(_) | KnownType::Integer(_) | KnownType::String(_) => {
2846                                unreachable!("Encountered recursive type level bool, int or str")
2847                            }
2848                            KnownType::Tuple => return Err(format!("({})", list)),
2849                            KnownType::Array => return Err(format!("[{}]", list)),
2850                            KnownType::Wire => return Err(format!("&{}", list)),
2851                            KnownType::Inverted => return Err(format!("inv {}", list)),
2852                        }
2853                    }
2854                }
2855            }
2856            TypeVar::Unknown(_, _, traits, _) => {
2857                for t in &traits.inner {
2858                    for param in &t.type_params {
2859                        if seen.contains(param) {
2860                            return Err("...".to_string());
2861                        }
2862
2863                        self.check_type_for_recursion(*param, seen)?;
2864                    }
2865                }
2866            }
2867        }
2868        seen.pop();
2869        Ok(())
2870    }
2871
2872    fn ensure_impls(
2873        &mut self,
2874        var: &TypeVarID,
2875        traits: &TraitList,
2876        trait_is_expected: bool,
2877        trait_list_loc: &Loc<()>,
2878        ctx: &Context,
2879    ) -> std::result::Result<Vec<(TraitImpl, TraitReq)>, UnificationError> {
2880        self.trace_stack.push(TraceStackEntry::EnsuringImpls(
2881            var.debug_resolve(self),
2882            traits.clone(),
2883            trait_is_expected,
2884        ));
2885
2886        let number = ctx
2887            .symtab
2888            .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
2889            .expect("Did not find number in symtab")
2890            .0;
2891
2892        macro_rules! error_producer {
2893            ($required_traits:expr) => {
2894                if trait_is_expected {
2895                    if $required_traits.inner.len() == 1
2896                        && $required_traits
2897                            .get_trait(&TraitName::Named(number.clone().nowhere()))
2898                            .is_some()
2899                    {
2900                        Err(UnificationError::Normal(Tm {
2901                            e: UnificationTrace::new(
2902                                self.new_generic_with_traits(*trait_list_loc, $required_traits),
2903                            ),
2904                            g: UnificationTrace::new(var.clone()),
2905                        }))
2906                    } else {
2907                        Err(UnificationError::UnsatisfiedTraits {
2908                            var: *var,
2909                            traits: $required_traits.inner,
2910                            target_loc: trait_list_loc.clone(),
2911                        })
2912                    }
2913                } else {
2914                    Err(UnificationError::Normal(Tm {
2915                        e: UnificationTrace::new(var.clone()),
2916                        g: UnificationTrace::new(
2917                            self.new_generic_with_traits(*trait_list_loc, $required_traits),
2918                        ),
2919                    }))
2920                }
2921            };
2922        }
2923
2924        match &var.resolve(self).clone() {
2925            TypeVar::Known(_, known, params) if known.into_impl_target().is_some() => {
2926                let Some(target) = known.into_impl_target() else {
2927                    unreachable!()
2928                };
2929
2930                let (impls, unsatisfied): (Vec<_>, Vec<_>) = traits
2931                    .inner
2932                    .iter()
2933                    .map(|trait_req| {
2934                        if let Some(impld) = self.trait_impls.inner.get(&target).cloned() {
2935                            // Get a list of implementations of this trait where the type
2936                            // parameters can match
2937                            let target_impls = impld
2938                                .iter()
2939                                .filter_map(|trait_impl| {
2940                                    self.checkpoint();
2941                                    let trait_params_match = trait_impl
2942                                        .trait_type_params
2943                                        .iter()
2944                                        .zip(trait_req.type_params.iter())
2945                                        .all(|(l, r)| {
2946                                            let l = l.make_copy(self);
2947                                            self.unify(&l, r, ctx).is_ok()
2948                                        });
2949
2950                                    let impl_params_match =
2951                                        trait_impl.target_type_params.iter().zip(params).all(
2952                                            |(l, r)| {
2953                                                let l = l.make_copy(self);
2954                                                self.unify(&l, r, ctx).is_ok()
2955                                            },
2956                                        );
2957                                    self.restore();
2958
2959                                    if trait_impl.name == trait_req.name
2960                                        && trait_params_match
2961                                        && impl_params_match
2962                                    {
2963                                        Some(trait_impl)
2964                                    } else {
2965                                        None
2966                                    }
2967                                })
2968                                .collect::<Vec<_>>();
2969
2970                            if target_impls.len() == 0 {
2971                                Ok(Either::Right(trait_req.clone()))
2972                            } else if target_impls.len() == 1 {
2973                                let target_impl = *target_impls.last().unwrap();
2974                                Ok(Either::Left((target_impl.clone(), trait_req.inner.clone())))
2975                            } else {
2976                                Err(UnificationError::Specific(diag_anyhow!(
2977                                    trait_req,
2978                                    "Found more than one impl of {} for {}",
2979                                    trait_req.display(self),
2980                                    var.display(self)
2981                                )))
2982                            }
2983                        } else {
2984                            Ok(Either::Right(trait_req.clone()))
2985                        }
2986                    })
2987                    .collect::<std::result::Result<Vec<_>, _>>()?
2988                    .into_iter()
2989                    .partition_map(|x| x);
2990
2991                if unsatisfied.is_empty() {
2992                    self.trace_stack.push(TraceStackEntry::Message(
2993                        "Ensuring impl successful".to_string(),
2994                    ));
2995                    Ok(impls)
2996                } else {
2997                    error_producer!(TraitList::from_vec(unsatisfied.clone()))
2998                }
2999            }
3000            TypeVar::Unknown(_, _, _, _) => {
3001                panic!("running ensure_impls on an unknown type")
3002            }
3003            _ => {
3004                if traits.inner.is_empty() {
3005                    Ok(vec![])
3006                } else {
3007                    error_producer!(traits.clone())
3008                }
3009            }
3010        }
3011    }
3012
3013    pub fn unify_expression_generic_error(
3014        &mut self,
3015        expr: &Loc<Expression>,
3016        other: &impl HasType,
3017        ctx: &Context,
3018    ) -> Result<TypeVarID> {
3019        self.unify(&expr.inner, other, ctx)
3020            .into_default_diagnostic(expr.loc(), self)
3021    }
3022
3023    pub fn check_requirements(&mut self, is_final_check: bool, ctx: &Context) -> Result<()> {
3024        // Once we are done type checking the rest of the entity, check all requirements
3025        loop {
3026            // Walk through all the requirements, checking each one. If the requirement
3027            // is still undetermined, take note to retain that id, otherwise store the
3028            // replacement to be performed
3029            let (retain, replacements_option): (Vec<_>, Vec<_>) = self
3030                .requirements
3031                .clone()
3032                .iter()
3033                .map(|req| match req.check(self, ctx)? {
3034                    requirements::RequirementResult::NoChange => Ok((true, None)),
3035                    requirements::RequirementResult::UnsatisfiedNow(diag) => {
3036                        if is_final_check {
3037                            Err(diag)
3038                        } else {
3039                            Ok((true, None))
3040                        }
3041                    }
3042                    requirements::RequirementResult::Satisfied(replacement) => {
3043                        self.trace_stack
3044                            .push(TraceStackEntry::ResolvedRequirement(req.clone()));
3045                        Ok((false, Some(replacement)))
3046                    }
3047                })
3048                .collect::<Result<Vec<_>>>()?
3049                .into_iter()
3050                .unzip();
3051
3052            let replacements = replacements_option
3053                .into_iter()
3054                .flatten()
3055                .flatten()
3056                .collect::<Vec<_>>();
3057
3058            // Drop all replacements that have now been applied
3059            self.requirements = self
3060                .requirements
3061                .drain(0..)
3062                .zip(retain)
3063                .filter_map(|(req, keep)| if keep { Some(req) } else { None })
3064                .collect();
3065
3066            if replacements.is_empty() {
3067                break;
3068            }
3069
3070            for Replacement { from, to, context } in replacements {
3071                self.unify(&to, &from, ctx).into_diagnostic_or_default(
3072                    from.loc(),
3073                    context,
3074                    self,
3075                )?;
3076            }
3077        }
3078
3079        Ok(())
3080    }
3081
3082    pub fn get_replacement(&self, var: &TypeVarID) -> TypeVarID {
3083        self.replacements.get(*var)
3084    }
3085
3086    pub fn do_and_restore<T, E>(
3087        &mut self,
3088        inner: impl Fn(&mut Self) -> std::result::Result<T, E>,
3089    ) -> std::result::Result<T, E> {
3090        self.checkpoint();
3091        let result = inner(self);
3092        self.restore();
3093        result
3094    }
3095
3096    /// Create a "checkpoint" to which we can restore later using `restore`. This acts
3097    /// like a stack, nested checkpoints are possible
3098    fn checkpoint(&mut self) {
3099        self.trace_stack
3100            .push(TraceStackEntry::Enter("Creating checkpoint".to_string()));
3101        self.replacements.push();
3102
3103        // This is relatively expensive if these lists are large, however, for now
3104        // this is a much simpler solution than attempting to roll-back replaced requirements
3105        // later
3106        self.checkpoints
3107            .push((self.requirements.clone(), self.constraints.clone()));
3108    }
3109
3110    fn restore(&mut self) {
3111        self.replacements.discard_top();
3112        self.trace_stack.push(TraceStackEntry::Exit);
3113
3114        let (requirements, constraints) = self
3115            .checkpoints
3116            .pop()
3117            .expect("Popped a checkpoint without any existing checkpoints.");
3118
3119        self.requirements = requirements;
3120        self.constraints = constraints;
3121    }
3122}
3123
3124impl TypeState {
3125    pub fn print_equations(&self) {
3126        for (lhs, rhs) in &self.equations {
3127            println!(
3128                "{} -> {}",
3129                format!("{lhs}").blue(),
3130                format!("{}", rhs.debug_resolve(self)).green()
3131            )
3132        }
3133
3134        println!("\nReplacments:");
3135
3136        for repl_stack in &self.replacements.all() {
3137            let replacements = { repl_stack.borrow().clone() };
3138            for (lhs, rhs) in replacements.iter().sorted() {
3139                println!(
3140                    "{} -> {} ({} -> {})",
3141                    format!("{}", lhs.inner).blue(),
3142                    format!("{}", rhs.inner).green(),
3143                    format!("{}", lhs.debug_resolve(self)).blue(),
3144                    format!("{}", rhs.debug_resolve(self)).green(),
3145                )
3146            }
3147            println!("---")
3148        }
3149
3150        println!("\n Requirements:");
3151
3152        for requirement in &self.requirements {
3153            println!("{:?}", requirement)
3154        }
3155
3156        println!()
3157    }
3158}
3159
3160#[must_use]
3161pub struct UnificationBuilder {
3162    lhs: TypeVarID,
3163    rhs: TypeVarID,
3164}
3165impl UnificationBuilder {
3166    pub fn commit(
3167        self,
3168        state: &mut TypeState,
3169        ctx: &Context,
3170    ) -> std::result::Result<TypeVarID, UnificationError> {
3171        state.unify(&self.lhs, &self.rhs, ctx)
3172    }
3173}
3174
3175pub trait HasType: std::fmt::Debug {
3176    fn get_type(&self, state: &TypeState) -> TypeVarID {
3177        self.try_get_type(state)
3178            .unwrap_or(state.error_type.unwrap())
3179    }
3180
3181    fn try_get_type(&self, state: &TypeState) -> Option<TypeVarID> {
3182        let id = self.get_type_impl(state);
3183        id.map(|id| state.get_replacement(&id))
3184    }
3185
3186    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID>;
3187
3188    fn unify_with(&self, rhs: &dyn HasType, state: &TypeState) -> UnificationBuilder {
3189        UnificationBuilder {
3190            lhs: self.get_type(state),
3191            rhs: rhs.get_type(state),
3192        }
3193    }
3194}
3195
3196impl HasType for TypeVarID {
3197    fn get_type_impl(&self, _state: &TypeState) -> Option<TypeVarID> {
3198        Some(*self)
3199    }
3200}
3201impl HasType for Loc<TypeVarID> {
3202    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3203        self.inner.try_get_type(state)
3204    }
3205}
3206impl HasType for TypedExpression {
3207    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3208        state.maybe_type_of(self).cloned()
3209    }
3210}
3211impl HasType for Expression {
3212    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3213        state.maybe_type_of(&TypedExpression::Id(self.id)).cloned()
3214    }
3215}
3216impl HasType for Loc<Expression> {
3217    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3218        state
3219            .maybe_type_of(&TypedExpression::Id(self.inner.id))
3220            .cloned()
3221    }
3222}
3223impl HasType for Pattern {
3224    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3225        state.maybe_type_of(&TypedExpression::Id(self.id)).cloned()
3226    }
3227}
3228impl HasType for Loc<Pattern> {
3229    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3230        state
3231            .maybe_type_of(&TypedExpression::Id(self.inner.id))
3232            .cloned()
3233    }
3234}
3235impl HasType for NameID {
3236    fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3237        state
3238            .maybe_type_of(&TypedExpression::Name(self.clone()))
3239            .cloned()
3240    }
3241}