1use std::cell::RefCell;
10use std::cmp::PartialEq;
11use std::collections::{BTreeMap, BTreeSet, HashMap};
12
13use colored::Colorize;
14use fixed_types::{t_int, t_uint};
15use hir::expression::CallKind;
16use hir::{
17 param_util, Binding, ConstGeneric, Parameter, PipelineRegMarkerExtra, TypeExpression, TypeSpec,
18 UnitHead, UnitKind, WalTrace, WhereClause,
19};
20use itertools::{Either, Itertools};
21use method_resolution::{FunctionLikeName, IntoImplTarget};
22use num::{BigInt, BigUint, Zero};
23use replacement::ReplacementStack;
24use serde::{Deserialize, Serialize};
25use spade_common::id_tracker::{ExprID, ImplID};
26use spade_common::num_ext::InfallibleToBigInt;
27use spade_diagnostics::diag_list::{DiagList, ResultExt};
28use spade_diagnostics::{diag_anyhow, diag_bail, Diagnostic};
29use spade_macros::trace_typechecker;
30use spade_types::meta_types::{unify_meta, MetaType};
31use trace_stack::TraceStack;
32use tracing::{info, trace};
33
34use spade_common::location_info::{Loc, WithLocation};
35use spade_common::name::{Identifier, NameID, Path};
36use spade_hir::param_util::{match_args_with_params, Argument};
37use spade_hir::symbol_table::{Patternable, PatternableKind, SymbolTable, TypeSymbol};
38use spade_hir::{self as hir, ConstGenericWithId, ImplTarget};
39use spade_hir::{
40 ArgumentList, Block, ExprKind, Expression, ItemList, Pattern, PatternArgument, Register,
41 Statement, TraitName, TraitSpec, TypeParam, Unit,
42};
43use spade_types::KnownType;
44
45use constraints::{
46 bits_to_store, ce_int, ce_var, ConstraintExpr, ConstraintSource, TypeConstraints,
47};
48use equation::{TemplateTypeVarID, TypeEquations, TypeVar, TypeVarID, TypedExpression};
49use error::{
50 error_pattern_type_mismatch, Result, UnificationError, UnificationErrorExt, UnificationTrace,
51};
52use requirements::{Replacement, Requirement};
53use trace_stack::{format_trace_stack, TraceStackEntry};
54use traits::{TraitList, TraitReq};
55
56use crate::error::TypeMismatch as Tm;
57use crate::requirements::ConstantInt;
58use crate::traits::{TraitImpl, TraitImplList};
59
60mod constraints;
61pub mod dump;
62pub mod equation;
63pub mod error;
64pub mod expression;
65pub mod fixed_types;
66pub mod method_resolution;
67pub mod mir_type_lowering;
68mod replacement;
69mod requirements;
70pub mod testutil;
71pub mod trace_stack;
72pub mod traits;
73
/// Read-only compiler data that the type checker needs while visiting units
/// and expressions.
pub struct Context<'a> {
    /// Symbol table used to look up units, traits and names.
    pub symtab: &'a SymbolTable,
    /// The HIR item list of the program being type checked.
    pub items: &'a ItemList,
    /// Known trait implementations. Copied into `TypeState::trait_impls` when a
    /// unit is visited.
    pub trait_impls: &'a TraitImplList,
}
79impl<'a> std::fmt::Debug for Context<'a> {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(f, "{{context omitted}}")
82 }
83}
84
/// Pushes a `format!`-style message onto `$self.trace_stack` as a
/// `TraceStackEntry::Message`, e.g. `add_trace!(self, "unifying {a:?}")`.
///
/// The entry type is `TraceStackEntry`; `TraceStack` itself is the stack type
/// and has no `Message` variant.
// Presumably only used as an ad-hoc debugging aid, hence the allow.
#[allow(unused_macros)]
macro_rules! add_trace {
    ($self:expr, $($arg : tt) *) => {
        $self.trace_stack.push(TraceStackEntry::Message(format!($($arg)*)))
    }
}
92
/// Describes what a generic list passed to `create_generic_list` belongs to.
/// NOTE(review): presumably converted into the owned `GenericListToken` used to
/// key `TypeState::generic_lists` — confirm in `create_generic_list`.
#[derive(Debug)]
pub enum GenericListSource<'a> {
    /// A generic list that is not tied to a named item.
    Anonymous,
    /// The generic list of a unit definition, identified by the unit's name.
    /// Used when type checking the unit itself.
    Definition(&'a NameID),
    /// The generic list of an `impl` block.
    ImplBlock {
        target: &'a ImplTarget,
        id: ImplID,
    },
    /// A generic list created for a specific expression, e.g. the
    /// instantiation of a callee's generics at a call site, or a lambda
    /// definition.
    Expression(ExprID),
}
108
/// Owned, hashable and serializable key identifying a generic list in
/// `TypeState::generic_lists`. Mirrors the variants of `GenericListSource`.
#[derive(Clone, Hash, Eq, PartialEq, Debug, Serialize, Deserialize)]
pub enum GenericListToken {
    /// An anonymous list; presumably the number distinguishes separate
    /// anonymous lists — TODO confirm.
    Anonymous(usize),
    /// The generic list of the named unit definition.
    Definition(NameID),
    /// The generic list of an `impl` block.
    ImplBlock(ImplTarget, ImplID),
    /// A generic list created for the given expression.
    Expression(ExprID),
}
117
/// Explicit generic arguments (`::<...>`) written at a call site, together
/// with what is needed to resolve them.
#[derive(Debug)]
pub struct TurbofishCtx<'a> {
    /// The turbofish arguments as written.
    turbofish: &'a Loc<ArgumentList<TypeExpression>>,
    /// Generic list of the context in which the turbofish was written (the
    /// caller's list, not the callee's).
    prev_generic_list: &'a GenericListToken,
    /// Type checking context used while resolving the arguments.
    type_ctx: &'a Context<'a>,
}
124
/// State tracked while type checking the body of a pipeline (either a pipeline
/// unit or a pipeline lambda).
#[derive(Clone, Serialize, Deserialize)]
pub struct PipelineState {
    /// Depth of the stage currently being checked. Starts as the integer 0;
    /// presumably advanced at stage boundaries by code outside this view.
    current_stage_depth: TypeVarID,
    /// The declared depth of the pipeline, located at its depth expression.
    /// Unified with `current_stage_depth` once the unit body is checked.
    total_depth: Loc<TypeVarID>,
    /// Location of the pipeline body, used for depth mismatch diagnostics.
    pipeline_loc: Loc<()>,
}
131
/// All mutable state of the type inference engine.
#[derive(Clone, Serialize, Deserialize)]
pub struct TypeState {
    /// Storage for every type variable; `TypeVarID::inner` indexes into this.
    type_vars: Vec<TypeVar>,
    /// Random key stamped into every `TypeVarID` created by this state.
    key: u64,
    /// This state's key plus the keys of the states it was cloned from via
    /// `create_child`. NOTE(review): presumably used to validate that a
    /// `TypeVarID` belongs to this state — confirm.
    keys: BTreeSet<u64>,

    /// Maps expressions and names to their type variables.
    equations: TypeEquations,
    /// Counter for fresh type ids; presumably read by `new_typeid`.
    next_typeid: RefCell<u64>,
    /// Instantiated generic parameters, per generic list, mapping each
    /// generic's name to its type variable.
    generic_lists: HashMap<GenericListToken, HashMap<NameID, TypeVarID>>,

    /// Constraints between type-level integers, added via `add_constraint`.
    constraints: TypeConstraints,

    /// Outstanding requirements; not serialized.
    #[serde(skip)]
    requirements: Vec<Requirement>,

    replacements: ReplacementStack,
    /// Saved (requirements, constraints) snapshots; not serialized.
    /// Presumably used to roll back speculative unification — TODO confirm.
    #[serde(skip)]
    checkpoints: Vec<(Vec<Requirement>, TypeConstraints)>,

    /// Set while checking a pipeline body, `None` otherwise.
    pipeline_state: Option<PipelineState>,

    /// A type variable with `KnownType::Error`, created by `fresh`.
    error_type: Option<TypeVarID>,

    /// Trait impls, overwritten from the `Context` when a unit is visited.
    pub trait_impls: TraitImplList,

    /// Trace of type checking steps for debugging; not serialized.
    #[serde(skip)]
    pub trace_stack: TraceStack,

    /// Diagnostics collected during type checking; not serialized.
    #[serde(skip)]
    pub diags: DiagList,
}
188
189impl TypeState {
190 pub fn fresh() -> Self {
193 let key = fastrand::u64(..);
194 let mut result = Self {
195 type_vars: vec![],
196 key,
197 keys: [key].into_iter().collect(),
198 equations: HashMap::new(),
199 next_typeid: RefCell::new(0),
200 trace_stack: TraceStack::new(),
201 constraints: TypeConstraints::new(),
202 requirements: vec![],
203 replacements: ReplacementStack::new(),
204 generic_lists: HashMap::new(),
205 trait_impls: TraitImplList::new(),
206 checkpoints: vec![],
207 error_type: None,
208 pipeline_state: None,
209 diags: DiagList::new(),
210 };
211 result.error_type =
212 Some(result.add_type_var(TypeVar::Known(().nowhere(), KnownType::Error, vec![])));
213 result
214 }
215
216 pub fn create_child(&self) -> Self {
217 let mut result = self.clone();
218 result.key = fastrand::u64(..);
219 result.keys.insert(result.key);
220 result
221 }
222
223 pub fn add_type_var(&mut self, var: TypeVar) -> TypeVarID {
224 let idx = self.type_vars.len();
225 self.type_vars.push(var);
226 TypeVarID {
227 inner: idx,
228 type_state_key: self.key,
229 }
230 }
231
232 pub fn get_equations(&self) -> &TypeEquations {
233 &self.equations
234 }
235
236 pub fn get_constraints(&self) -> &TypeConstraints {
237 &self.constraints
238 }
239
240 pub fn get_generic_list<'a>(
242 &'a self,
243 generic_list_token: &'a GenericListToken,
244 ) -> Option<&'a HashMap<NameID, TypeVarID>> {
245 self.generic_lists.get(generic_list_token)
246 }
247
248 #[tracing::instrument(level = "trace", skip_all)]
249 fn hir_type_expr_to_var<'a>(
250 &'a mut self,
251 e: &Loc<hir::TypeExpression>,
252 generic_list_token: &GenericListToken,
253 ) -> Result<TypeVarID> {
254 let id = match &e.inner {
255 hir::TypeExpression::Integer(i) => self.add_type_var(TypeVar::Known(
256 e.loc(),
257 KnownType::Integer(i.clone()),
258 vec![],
259 )),
260 hir::TypeExpression::String(s) => self.add_type_var(TypeVar::Known(
261 e.loc(),
262 KnownType::String(s.clone()),
263 vec![],
264 )),
265 hir::TypeExpression::TypeSpec(spec) => {
266 self.type_var_from_hir(e.loc(), &spec.clone(), generic_list_token)?
267 }
268 hir::TypeExpression::ConstGeneric(g) => {
269 let constraint = self.visit_const_generic(g, generic_list_token)?;
270
271 let tvar = self.new_generic_tlnumber(e.loc());
272 self.add_constraint(
273 tvar.clone(),
274 constraint,
275 g.loc(),
276 &tvar,
277 ConstraintSource::Where,
278 );
279
280 tvar
281 }
282 };
283 Ok(id)
284 }
285
286 #[tracing::instrument(level = "trace", skip_all, fields(%hir_type))]
287 pub fn type_var_from_hir<'a>(
288 &'a mut self,
289 loc: Loc<()>,
290 hir_type: &crate::hir::TypeSpec,
291 generic_list_token: &GenericListToken,
292 ) -> Result<TypeVarID> {
293 let generic_list = self.get_generic_list(generic_list_token);
294 let var = match &hir_type {
295 hir::TypeSpec::Declared(base, params) => {
296 let params = params
297 .iter()
298 .map(|e| self.hir_type_expr_to_var(e, generic_list_token))
299 .collect::<Result<Vec<_>>>()?;
300
301 self.add_type_var(TypeVar::Known(
302 loc,
303 KnownType::Named(base.inner.clone()),
304 params,
305 ))
306 }
307 hir::TypeSpec::Generic(name) => match generic_list
308 .ok_or_else(|| diag_anyhow!(loc, "Found no generic list for {name}"))?
309 .get(&name.inner)
310 {
311 Some(t) => t.clone(),
312 None => {
313 for list_source in self.generic_lists.keys() {
314 info!("Generic lists exist for {list_source:?}");
315 }
316 info!("Current source is {generic_list_token:?}");
317 panic!("No entry in generic list for {name:?}");
318 }
319 },
320 hir::TypeSpec::Tuple(inner) => {
321 let inner = inner
322 .iter()
323 .map(|t| self.type_var_from_hir(loc, t, generic_list_token))
324 .collect::<Result<_>>()?;
325 self.add_type_var(TypeVar::tuple(loc, inner))
326 }
327 hir::TypeSpec::Array { inner, size } => {
328 let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
329 let size = self.hir_type_expr_to_var(size, generic_list_token)?;
330
331 self.add_type_var(TypeVar::array(loc, inner, size))
332 }
333 hir::TypeSpec::Wire(inner) => {
334 let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
335 self.add_type_var(TypeVar::wire(loc, inner))
336 }
337 hir::TypeSpec::Inverted(inner) => {
338 let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
339 self.add_type_var(TypeVar::inverted(loc, inner))
340 }
341 hir::TypeSpec::Wildcard(_) => self.new_generic_any(),
342 hir::TypeSpec::TraitSelf(_) => {
343 diag_bail!(
344 loc,
345 "Trying to convert TraitSelf to type inference type var"
346 )
347 }
348 };
349
350 Ok(var)
351 }
352
353 pub fn type_of(&self, expr: &TypedExpression) -> TypeVarID {
356 if let Some(t) = self.equations.get(expr) {
357 *t
358 } else {
359 panic!("Tried looking up the type of {expr:?} but it was not found")
360 }
361 }
362
363 pub fn maybe_type_of(&self, expr: &TypedExpression) -> Option<&TypeVarID> {
364 self.equations.get(expr)
365 }
366
367 pub fn new_generic_int(&mut self, loc: Loc<()>, symtab: &SymbolTable) -> TypeVar {
368 TypeVar::Known(loc, t_int(symtab), vec![self.new_generic_tluint(loc)])
369 }
370
371 pub fn new_concrete_int(&mut self, size: BigUint, loc: Loc<()>) -> TypeVarID {
372 TypeVar::Known(loc, KnownType::Integer(size.to_bigint()), vec![]).insert(self)
373 }
374
375 pub fn new_split_generic_int(
378 &mut self,
379 loc: Loc<()>,
380 symtab: &SymbolTable,
381 ) -> (TypeVarID, TypeVarID) {
382 let size = self.new_generic_tlint(loc);
383 let full = self.add_type_var(TypeVar::Known(loc, t_int(symtab), vec![size.clone()]));
384 (full, size)
385 }
386
387 pub fn new_split_generic_uint(
388 &mut self,
389 loc: Loc<()>,
390 symtab: &SymbolTable,
391 ) -> (TypeVarID, TypeVarID) {
392 let size = self.new_generic_tluint(loc);
393 let full = self.add_type_var(TypeVar::Known(loc, t_uint(symtab), vec![size.clone()]));
394 (full, size)
395 }
396
397 pub fn new_generic_with_meta(&mut self, loc: Loc<()>, meta: MetaType) -> TypeVarID {
398 let id = self.new_typeid();
399 self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), meta))
400 }
401
402 pub fn new_generic_type(&mut self, loc: Loc<()>) -> TypeVarID {
403 let id = self.new_typeid();
404 self.add_type_var(TypeVar::Unknown(
405 loc,
406 id,
407 TraitList::empty(),
408 MetaType::Type,
409 ))
410 }
411
412 pub fn new_generic_any(&mut self) -> TypeVarID {
413 let id = self.new_typeid();
414 self.add_type_var(TypeVar::Unknown(
417 ().nowhere(),
418 id,
419 TraitList::empty(),
420 MetaType::Any,
421 ))
422 }
423
424 pub fn new_generic_tlbool(&mut self, loc: Loc<()>) -> TypeVarID {
425 let id = self.new_typeid();
426 self.add_type_var(TypeVar::Unknown(
427 loc,
428 id,
429 TraitList::empty(),
430 MetaType::Bool,
431 ))
432 }
433
434 pub fn new_generic_tlstr(&mut self, loc: Loc<()>) -> TypeVarID {
435 let id = self.new_typeid();
436 self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Str))
437 }
438
439 pub fn new_generic_tluint(&mut self, loc: Loc<()>) -> TypeVarID {
440 let id = self.new_typeid();
441 self.add_type_var(TypeVar::Unknown(
442 loc,
443 id,
444 TraitList::empty(),
445 MetaType::Uint,
446 ))
447 }
448
449 pub fn new_generic_tlint(&mut self, loc: Loc<()>) -> TypeVarID {
450 let id = self.new_typeid();
451 self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Int))
452 }
453
454 pub fn new_generic_tlnumber(&mut self, loc: Loc<()>) -> TypeVarID {
455 let id = self.new_typeid();
456 self.add_type_var(TypeVar::Unknown(
457 loc,
458 id,
459 TraitList::empty(),
460 MetaType::Number,
461 ))
462 }
463
464 pub fn new_generic_number(&mut self, loc: Loc<()>, ctx: &Context) -> (TypeVarID, TypeVarID) {
465 let number = ctx
466 .symtab
467 .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
468 .expect("Did not find number in symtab")
469 .0;
470 let id = self.new_typeid();
471 let size = self.new_generic_tluint(loc);
472 let t = TraitReq {
473 name: TraitName::Named(number.nowhere()),
474 type_params: vec![size.clone()],
475 }
476 .nowhere();
477 (
478 self.add_type_var(TypeVar::Unknown(
479 loc,
480 id,
481 TraitList::from_vec(vec![t]),
482 MetaType::Type,
483 )),
484 size,
485 )
486 }
487
488 pub fn new_generic_with_traits(&mut self, loc: Loc<()>, traits: TraitList) -> TypeVarID {
489 let id = self.new_typeid();
490 self.add_type_var(TypeVar::Unknown(loc, id, traits, MetaType::Type))
491 }
492
493 pub fn get_pipeline_state<T>(&self, access_loc: &Loc<T>) -> Result<&PipelineState> {
497 self.pipeline_state
498 .as_ref()
499 .ok_or_else(|| diag_anyhow!(access_loc, "Expected to have a pipeline state"))
500 }
501
502 pub fn visit_unit_with_preprocessing(
504 &mut self,
505 entity: &Loc<Unit>,
506 pp: impl Fn(&mut TypeState, &Loc<Unit>, &GenericListToken, &Context) -> Result<()>,
507 ctx: &Context,
508 ) -> Result<()> {
509 self.trait_impls = ctx.trait_impls.clone();
510
511 let generic_list = self.create_generic_list(
512 GenericListSource::Definition(&entity.name.name_id().inner),
513 &entity.head.unit_type_params,
514 &entity.head.scope_type_params,
515 None,
516 &entity.head.where_clauses,
519 )?;
520
521 pp(self, entity, &generic_list, ctx)?;
522
523 for (name, t) in &entity.inputs {
525 let tvar = self.type_var_from_hir(t.loc(), t, &generic_list)?;
526 self.add_equation(TypedExpression::Name(name.inner.clone()), tvar)
527 }
528
529 if let UnitKind::Pipeline {
530 depth,
531 depth_typeexpr_id,
532 } = &entity.head.unit_kind.inner
533 {
534 self.setup_pipeline_state(
535 &entity.head.unit_kind,
536 &entity.body,
537 &generic_list,
538 depth,
539 depth_typeexpr_id,
540 )?;
541
542 let clock_index = if entity.head.is_nonstatic_method {
543 1
544 } else {
545 0
546 };
547
548 TypedExpression::Name(entity.inputs[clock_index].0.clone().inner)
549 .unify_with(&self.t_clock(entity.head.unit_kind.loc(), ctx.symtab), self)
550 .commit(self, ctx)
551 .into_diagnostic(
552 entity.inputs[0].1.loc(),
553 |diag,
554 Tm {
555 g: got,
556 e: _expected,
557 }| {
558 diag.message(format!(
559 "First pipeline argument must be a clock. Got {}",
560 got.display(self)
561 ))
562 .primary_label("expected clock")
563 },
564 self,
565 )?;
566 self.check_requirements(false, ctx)?;
569 }
570
571 self.visit_expression(&entity.body, ctx, &generic_list);
572
573 if let Some(output_type) = &entity.head.output_type {
575 let tvar = self.type_var_from_hir(output_type.loc(), output_type, &generic_list)?;
576
577 self.trace_stack.push(TraceStackEntry::Message(format!(
578 "Unifying with output type {}",
579 tvar.debug_resolve(self)
580 )));
581 self.unify(&TypedExpression::Id(entity.body.inner.id), &tvar, ctx)
582 .into_diagnostic_no_expected_source(
583 &entity.body,
584 |diag,
585 Tm {
586 g: got,
587 e: expected,
588 }| {
589 let expected = expected.display(self);
590 let got = got.display(self);
591 diag.message(format!(
594 "Output type mismatch. Expected {expected}, got {got}"
595 ))
596 .primary_label(format!("Found type {got}"))
597 .secondary_label(output_type, format!("{expected} type specified here"))
598 },
599 self,
600 )?;
601 } else {
602 TypedExpression::Id(entity.body.inner.id)
604 .unify_with(&self.add_type_var(TypeVar::unit(entity.head.name.loc())), self)
605 .commit(self, ctx)
606 .into_diagnostic_no_expected_source(entity.body.loc(), |diag, Tm{g: got, e: _expected}| {
607 diag.message("Output type mismatch")
608 .primary_label(format!("Found type {got}", got = got.display(self)))
609 .note(format!(
610 "The {} does not specify a return type.\nAdd a return type, or remove the return value.",
611 entity.head.unit_kind.name()
612 ))
613 }, self)?;
614 }
615
616 if let Some(PipelineState {
617 current_stage_depth,
618 pipeline_loc,
619 total_depth,
620 }) = self.pipeline_state.clone()
621 {
622 self.unify(&total_depth.inner, ¤t_stage_depth, ctx)
623 .into_diagnostic_no_expected_source(
624 pipeline_loc,
625 |diag, tm| {
626 let (e, g) = tm.display_e_g(self);
627 diag.message(format!("Pipeline depth mismatch. Expected {g} got {e}"))
628 .primary_label(format!("Found {e} stages in this pipeline"))
629 },
630 self,
631 )?;
632 }
633
634 self.check_requirements(true, ctx)?;
635
636 self.pipeline_state = None;
641
642 Ok(())
643 }
644
645 fn setup_pipeline_state(
646 &mut self,
647 unit_kind: &Loc<UnitKind>,
648 body_loc: &Loc<Expression>,
649 generic_list: &GenericListToken,
650 depth: &Loc<TypeExpression>,
651 depth_typeexpr_id: &ExprID,
652 ) -> Result<()> {
653 let depth_var = self.hir_type_expr_to_var(depth, generic_list)?;
654 self.add_equation(TypedExpression::Id(*depth_typeexpr_id), depth_var.clone());
655 self.pipeline_state = Some(PipelineState {
656 current_stage_depth: self.add_type_var(TypeVar::Known(
657 unit_kind.loc(),
658 KnownType::Integer(BigInt::zero()),
659 vec![],
660 )),
661 pipeline_loc: body_loc.loc(),
662 total_depth: depth_var.clone().at_loc(depth),
663 });
664 self.add_requirement(Requirement::PositivePipelineDepth {
665 depth: depth_var.at_loc(depth),
666 });
667 Ok(())
668 }
669
670 #[trace_typechecker]
671 #[tracing::instrument(level = "trace", skip_all, fields(%entity.name))]
672 pub fn visit_unit(&mut self, entity: &Loc<Unit>, ctx: &Context) -> Result<()> {
673 self.visit_unit_with_preprocessing(entity, |_, _, _, _| Ok(()), ctx)
674 }
675
676 #[trace_typechecker]
677 #[tracing::instrument(level = "trace", skip_all)]
678 fn visit_argument_list(
679 &mut self,
680 args: &Loc<ArgumentList<Expression>>,
681 ctx: &Context,
682 generic_list: &GenericListToken,
683 ) -> Result<()> {
684 for expr in args.expressions() {
685 self.visit_expression(expr, ctx, generic_list);
686 }
687 Ok(())
688 }
689
690 #[trace_typechecker]
691 fn type_check_argument_list(
692 &mut self,
693 args: &[Argument<Expression, TypeSpec>],
694 ctx: &Context,
695 generic_list: &GenericListToken,
696 ) -> Result<()> {
697 for Argument {
698 target,
699 target_type,
700 value,
701 kind,
702 } in args.iter()
703 {
704 let target_type = self.type_var_from_hir(value.loc(), target_type, generic_list)?;
705
706 let loc = match kind {
707 hir::param_util::ArgumentKind::Positional => value.loc(),
708 hir::param_util::ArgumentKind::Named => value.loc(),
709 hir::param_util::ArgumentKind::ShortNamed => target.loc(),
710 };
711
712 self.unify(&value.inner, &target_type, ctx)
713 .into_diagnostic(
714 loc,
715 |d, tm| {
716 let (expected, got) = tm.display_e_g(self);
717 d.message(format!(
718 "Argument type mismatch. Expected {expected} got {got}"
719 ))
720 .primary_label(format!("expected {expected}"))
721 },
722 self,
723 )?;
724 }
725
726 Ok(())
727 }
728
729 #[trace_typechecker]
730 pub fn visit_expression_result(
731 &mut self,
732 expression: &Loc<Expression>,
733 ctx: &Context,
734 generic_list: &GenericListToken,
735 new_type: TypeVarID,
736 ) -> Result<()> {
737 match &expression.inner.kind {
739 ExprKind::Error => {
740 new_type
741 .unify_with(&self.t_err(expression.loc()), self)
742 .commit(self, ctx)
743 .unwrap();
744 }
745 ExprKind::Identifier(_) => self.visit_identifier(expression, ctx)?,
746 ExprKind::TypeLevelInteger(_) => {
747 self.visit_type_level_integer(expression, generic_list, ctx)?
748 }
749 ExprKind::IntLiteral(_, _) => self.visit_int_literal(expression, ctx)?,
750 ExprKind::BoolLiteral(_) => self.visit_bool_literal(expression, ctx)?,
751 ExprKind::TriLiteral(_) => self.visit_tri_literal(expression, ctx)?,
752 ExprKind::TupleLiteral(_) => self.visit_tuple_literal(expression, ctx, generic_list)?,
753 ExprKind::TupleIndex(_, _) => self.visit_tuple_index(expression, ctx, generic_list)?,
754 ExprKind::ArrayLiteral(_) => self.visit_array_literal(expression, ctx, generic_list)?,
755 ExprKind::ArrayShorthandLiteral(_, _) => {
756 self.visit_array_shorthand_literal(expression, ctx, generic_list)?
757 }
758 ExprKind::CreatePorts => self.visit_create_ports(expression, ctx, generic_list)?,
759 ExprKind::FieldAccess(_, _) => {
760 self.visit_field_access(expression, ctx, generic_list)?
761 }
762 ExprKind::MethodCall { .. } => self.visit_method_call(expression, ctx, generic_list)?,
763 ExprKind::Index(_, _) => self.visit_index(expression, ctx, generic_list)?,
764 ExprKind::RangeIndex { .. } => self.visit_range_index(expression, ctx, generic_list)?,
765 ExprKind::Block(_) => self.visit_block_expr(expression, ctx, generic_list)?,
766 ExprKind::If(_, _, _) => self.visit_if(expression, ctx, generic_list)?,
767 ExprKind::Match(_, _) => self.visit_match(expression, ctx, generic_list)?,
768 ExprKind::BinaryOperator(_, _, _) => {
769 self.visit_binary_operator(expression, ctx, generic_list)?
770 }
771 ExprKind::UnaryOperator(_, _) => {
772 self.visit_unary_operator(expression, ctx, generic_list)?
773 }
774 ExprKind::Call {
775 kind,
776 callee,
777 args,
778 turbofish,
779 safety: _,
780 } => {
781 let head = ctx.symtab.unit_by_id(&callee.inner);
782
783 self.handle_function_like(
784 expression.map_ref(|e| e.id),
785 &expression.get_type(self),
786 &FunctionLikeName::Free(callee.inner.clone()),
787 &head,
788 kind,
789 args,
790 ctx,
791 true,
792 false,
793 turbofish.as_ref().map(|turbofish| TurbofishCtx {
794 turbofish,
795 prev_generic_list: generic_list,
796 type_ctx: ctx,
797 }),
798 generic_list,
799 )?;
800 }
801 ExprKind::PipelineRef { .. } => {
802 self.visit_pipeline_ref(expression, generic_list, ctx)?;
803 }
804 ExprKind::StageReady | ExprKind::StageValid => {
805 expression
806 .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
807 .commit(self, ctx)
808 .into_default_diagnostic(expression, self)?;
809 }
810
811 ExprKind::TypeLevelIf(cond, on_true, on_false) => {
812 let cond_var = self.visit_const_generic_with_id(
813 cond,
814 generic_list,
815 ConstraintSource::TypeLevelIf,
816 ctx,
817 )?;
818 let t_bool = self.new_generic_tlbool(cond.loc());
819 self.unify(&cond_var, &t_bool, ctx).into_diagnostic(
820 cond,
821 |diag, tm| {
822 let (_e, g) = tm.display_e_g(self);
823 diag.message(format!("gen if conditions must be #bool, got {g}"))
824 },
825 self,
826 )?;
827
828 self.visit_expression(on_true, ctx, generic_list);
829 self.visit_expression(on_false, ctx, generic_list);
830
831 self.unify_expression_generic_error(expression, on_true.as_ref(), ctx)?;
832 self.unify_expression_generic_error(expression, on_false.as_ref(), ctx)?;
833 }
834 ExprKind::LambdaDef {
835 unit_kind,
836 arguments,
837 body,
838 lambda_type,
839 type_params,
840 outer_generic_params,
841 lambda_unit: _,
842 clock,
843 captures,
844 } => {
845 for arg in arguments {
846 self.visit_pattern(arg, ctx, generic_list)?;
847 }
848
849 if let Some(clock) = clock {
850 clock
851 .unify_with(&self.t_clock(clock.loc(), ctx.symtab), self)
852 .commit(self, ctx)
853 .into_default_diagnostic(clock, self)?;
854 }
855
856 let outer_pipeline_state = self.pipeline_state.take();
857 if let UnitKind::Pipeline {
858 depth,
859 depth_typeexpr_id,
860 } = &unit_kind.inner
861 {
862 self.setup_pipeline_state(
863 unit_kind,
864 body,
865 &generic_list,
866 depth,
867 depth_typeexpr_id,
868 )?;
869 }
870 self.visit_expression(body, ctx, generic_list);
871 self.pipeline_state = outer_pipeline_state;
872
873 let lambda_params = arguments
874 .iter()
875 .map(|arg| arg.get_type(self))
876 .chain(vec![body.get_type(self)])
877 .chain(captures.iter().map(|(_, cap_name)| cap_name.get_type(self)))
878 .chain(
879 outer_generic_params
880 .iter()
881 .map(|cap| {
882 let t = self
883 .get_generic_list(generic_list)
884 .ok_or_else(|| {
885 diag_anyhow!(
886 expression,
887 "Found a captured generic but no generic list"
888 )
889 })?
890 .get(&cap.name_in_body)
891 .ok_or_else(|| {
892 diag_anyhow!(
893 &cap.name_in_body,
894 "Did not find an entry for {} in lambda generic list",
895 cap.name_in_body
896 )
897 });
898 Ok(t?.clone())
899 })
900 .collect::<Result<Vec<_>>>()?
901 .into_iter(),
902 )
903 .collect::<Vec<_>>();
904
905 let self_type = TypeVar::Known(
906 expression.loc(),
907 KnownType::Named(lambda_type.clone()),
908 lambda_params.clone(),
909 );
910
911 let unit_generic_list = self.create_generic_list(
912 GenericListSource::Expression(expression.id),
913 &type_params.all().cloned().collect::<Vec<_>>(),
914 &[],
915 None,
916 &[],
917 )?;
918
919 for (p, tp) in lambda_params.iter().zip(type_params.all()) {
920 let gl = self.get_generic_list(&unit_generic_list).unwrap();
921 p.unify_with(
923 gl.get(&tp.name_id).ok_or_else(|| {
924 diag_anyhow!(
925 expression,
926 "Lambda unit list did not contain {}",
927 tp.name_id
928 )
929 })?,
930 self,
931 )
932 .commit(self, ctx)
933 .into_default_diagnostic(expression, self)?;
934 }
935 expression
936 .unify_with(&self.add_type_var(self_type), self)
937 .commit(self, ctx)
938 .into_default_diagnostic(expression, self)?;
939 }
940 ExprKind::StaticUnreachable(_) => {}
941 ExprKind::Null => {}
942 }
943 Ok(())
944 }
945
946 #[tracing::instrument(level = "trace", skip_all)]
947 pub fn visit_expression(
948 &mut self,
949 expression: &Loc<Expression>,
950 ctx: &Context,
951 generic_list: &GenericListToken,
952 ) {
953 let new_type = self.new_generic_type(expression.loc());
954 self.add_equation(TypedExpression::Id(expression.inner.id), new_type);
955
956 match self.visit_expression_result(expression, ctx, generic_list, new_type) {
957 Ok(_) => {}
958 Err(e) => {
959 new_type
960 .unify_with(&self.t_err(expression.loc()), self)
961 .commit(self, ctx)
962 .unwrap();
963
964 self.diags.errors.push(e);
965 }
966 }
967 }
968
969 #[tracing::instrument(level = "trace", skip_all, fields(%name))]
971 #[trace_typechecker]
972 fn handle_function_like(
973 &mut self,
974 expression_id: Loc<ExprID>,
975 expression_type: &TypeVarID,
976 name: &FunctionLikeName,
977 head: &Loc<UnitHead>,
978 call_kind: &CallKind,
979 args: &Loc<ArgumentList<Expression>>,
980 ctx: &Context,
981 visit_args: bool,
985 is_method: bool,
988 turbofish: Option<TurbofishCtx>,
989 generic_list: &GenericListToken,
990 ) -> Result<()> {
991 let unit_generic_list = self.create_generic_list(
993 GenericListSource::Expression(expression_id.inner),
994 &head.unit_type_params,
995 &head.scope_type_params,
996 turbofish,
997 &head.where_clauses,
998 )?;
999
1000 match (&head.unit_kind.inner, call_kind) {
1001 (
1002 UnitKind::Pipeline {
1003 depth: udepth,
1004 depth_typeexpr_id: _,
1005 },
1006 CallKind::Pipeline {
1007 inst_loc: _,
1008 depth: cdepth,
1009 depth_typeexpr_id: cdepth_typeexpr_id,
1010 },
1011 ) => {
1012 let definition_depth = self.hir_type_expr_to_var(udepth, &unit_generic_list)?;
1013 let call_depth = self.hir_type_expr_to_var(cdepth, generic_list)?;
1014
1015 self.add_equation(TypedExpression::Id(*cdepth_typeexpr_id), call_depth.clone());
1019
1020 self.unify(&call_depth, &definition_depth, ctx)
1021 .into_diagnostic_no_expected_source(
1022 cdepth,
1023 |diag, tm| {
1024 let (e, g) = tm.display_e_g(self);
1025 diag.message("Pipeline depth mismatch")
1026 .primary_label(format!("Expected depth {e}, got {g}"))
1027 .secondary_label(udepth, format!("{name} has depth {e}"))
1028 },
1029 self,
1030 )?;
1031 }
1032 _ => {}
1033 }
1034
1035 if visit_args {
1036 self.visit_argument_list(args, ctx, &generic_list)?;
1037 }
1038
1039 let type_params = &head.get_type_params();
1040
1041 macro_rules! handle_special_functions {
1043 ($([$($path:expr),*] => $handler:expr),*) => {
1044 $(
1045 let path = Path(vec![$(Identifier($path.to_string()).nowhere()),*]).nowhere();
1046 if ctx.symtab
1047 .try_lookup_final_id(&path, &[])
1048 .map(|n| &FunctionLikeName::Free(n) == name)
1049 .unwrap_or(false)
1050 {
1051 $handler
1052 };
1053 )*
1054 }
1055 }
1056
1057 macro_rules! generic_arg {
1060 ($idx:expr) => {
1061 self.get_generic_list(&unit_generic_list)
1062 .ok_or_else(|| diag_anyhow!(expression_id, "Found no generic list for call"))?
1063 [&type_params[$idx].name_id()]
1064 .clone()
1065 };
1066 }
1067
1068 let matched_args =
1069 match_args_with_params(args, &head.inputs.inner, is_method).map_err(|e| {
1070 let diag: Diagnostic = e.into();
1071 diag.secondary_label(
1072 head,
1073 format!("{kind} defined here", kind = head.unit_kind.name()),
1074 )
1075 })?;
1076
1077 handle_special_functions! {
1078 ["std", "conv", "concat"] => {
1079 self.handle_concat(
1080 expression_id,
1081 generic_arg!(0),
1082 generic_arg!(1),
1083 generic_arg!(2),
1084 &matched_args,
1085 ctx
1086 )?
1087 },
1088 ["std", "conv", "trunc"] => {
1089 self.handle_trunc(
1090 expression_id,
1091 generic_arg!(0),
1092 generic_arg!(1),
1093 &matched_args,
1094 ctx
1095 )?
1096 },
1097 ["std", "ops", "comb_div"] => {
1098 self.handle_comb_mod_or_div(
1099 generic_arg!(0),
1100 &matched_args,
1101 ctx
1102 )?
1103 },
1104 ["std", "ops", "comb_mod"] => {
1105 self.handle_comb_mod_or_div(
1106 generic_arg!(0),
1107 &matched_args,
1108 ctx
1109 )?
1110 },
1111 ["std", "mem", "clocked_memory"] => {
1112 let num_elements = generic_arg!(0);
1113 let addr_size = generic_arg!(2);
1114
1115 self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
1116 },
1117 ["std", "mem", "clocked_memory_init"] => {
1120 let num_elements = generic_arg!(0);
1121 let addr_size = generic_arg!(2);
1122
1123 self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
1124 },
1125 ["std", "mem", "read_memory"] => {
1126 let addr_size = generic_arg!(0);
1127 let num_elements = generic_arg!(2);
1128
1129 self.handle_read_memory(num_elements, addr_size, &matched_args, ctx)?
1130 }
1131 };
1132
1133 self.type_check_argument_list(&matched_args, ctx, &unit_generic_list)?;
1135
1136 let return_type = head
1137 .output_type
1138 .as_ref()
1139 .map(|o| self.type_var_from_hir(expression_id.loc(), o, &unit_generic_list))
1140 .transpose()?
1141 .unwrap_or_else(|| {
1142 self.add_type_var(TypeVar::Known(
1143 expression_id.loc(),
1144 KnownType::Tuple,
1145 vec![],
1146 ))
1147 });
1148
1149 self.unify(expression_type, &return_type, ctx)
1150 .into_default_diagnostic(expression_id.loc(), self)?;
1151
1152 Ok(())
1153 }
1154
1155 pub fn handle_concat(
1156 &mut self,
1157 expression_id: Loc<ExprID>,
1158 source_lhs_ty: TypeVarID,
1159 source_rhs_ty: TypeVarID,
1160 source_result_ty: TypeVarID,
1161 args: &[Argument<Expression, TypeSpec>],
1162 ctx: &Context,
1163 ) -> Result<()> {
1164 let (lhs_type, lhs_size) = self.new_generic_number(expression_id.loc(), ctx);
1165 let (rhs_type, rhs_size) = self.new_generic_number(expression_id.loc(), ctx);
1166 let (result_type, result_size) = self.new_generic_number(expression_id.loc(), ctx);
1167 self.unify(&source_lhs_ty, &lhs_type, ctx)
1168 .into_default_diagnostic(args[0].value.loc(), self)?;
1169 self.unify(&source_rhs_ty, &rhs_type, ctx)
1170 .into_default_diagnostic(args[1].value.loc(), self)?;
1171 self.unify(&source_result_ty, &result_type, ctx)
1172 .into_default_diagnostic(expression_id.loc(), self)?;
1173
1174 self.add_constraint(
1176 result_size.clone(),
1177 ce_var(&lhs_size) + ce_var(&rhs_size),
1178 expression_id.loc(),
1179 &result_size,
1180 ConstraintSource::Concatenation,
1181 );
1182 self.add_constraint(
1183 lhs_size.clone(),
1184 ce_var(&result_size) + -ce_var(&rhs_size),
1185 args[0].value.loc(),
1186 &lhs_size,
1187 ConstraintSource::Concatenation,
1188 );
1189 self.add_constraint(
1190 rhs_size.clone(),
1191 ce_var(&result_size) + -ce_var(&lhs_size),
1192 args[1].value.loc(),
1193 &rhs_size,
1194 ConstraintSource::Concatenation,
1195 );
1196
1197 self.add_requirement(Requirement::SharedBase(vec![
1198 lhs_type.at_loc(args[0].value),
1199 rhs_type.at_loc(args[1].value),
1200 result_type.at_loc(&expression_id.loc()),
1201 ]));
1202 Ok(())
1203 }
1204
1205 pub fn handle_trunc(
1206 &mut self,
1207 expression_id: Loc<ExprID>,
1208 source_in_ty: TypeVarID,
1209 source_result_ty: TypeVarID,
1210 args: &[Argument<Expression, TypeSpec>],
1211 ctx: &Context,
1212 ) -> Result<()> {
1213 let (in_ty, _) = self.new_generic_number(expression_id.loc(), ctx);
1214 let (result_type, _) = self.new_generic_number(expression_id.loc(), ctx);
1215 self.unify(&source_in_ty, &in_ty, ctx)
1216 .into_default_diagnostic(args[0].value.loc(), self)?;
1217 self.unify(&source_result_ty, &result_type, ctx)
1218 .into_default_diagnostic(expression_id.loc(), self)?;
1219
1220 self.add_requirement(Requirement::SharedBase(vec![
1221 in_ty.at_loc(args[0].value),
1222 result_type.at_loc(&expression_id.loc()),
1223 ]));
1224 Ok(())
1225 }
1226
1227 pub fn handle_comb_mod_or_div(
1228 &mut self,
1229 n_ty: TypeVarID,
1230 args: &[Argument<Expression, TypeSpec>],
1231 ctx: &Context,
1232 ) -> Result<()> {
1233 let (num, _) = self.new_generic_number(args[0].value.loc(), ctx);
1234 self.unify(&n_ty, &num, ctx)
1235 .into_default_diagnostic(args[0].value.loc(), self)?;
1236 Ok(())
1237 }
1238
1239 pub fn handle_clocked_memory(
1240 &mut self,
1241 num_elements: TypeVarID,
1242 addr_size_arg: TypeVarID,
1243 args: &[Argument<Expression, TypeSpec>],
1244 ctx: &Context,
1245 ) -> Result<()> {
1246 let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
1249 let arg1_loc = args[1].value.loc();
1250 let tup = TypeVar::tuple(
1251 args[1].value.loc(),
1252 vec![
1253 self.new_generic_type(arg1_loc),
1254 addr_type,
1255 self.new_generic_type(arg1_loc),
1256 ],
1257 );
1258 let port_type = TypeVar::array(
1259 arg1_loc,
1260 self.add_type_var(tup),
1261 self.new_generic_tluint(arg1_loc),
1262 )
1263 .insert(self);
1264
1265 self.add_constraint(
1266 addr_size.clone(),
1267 bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
1268 args[1].value.loc(),
1269 &port_type,
1270 ConstraintSource::MemoryIndexing,
1271 );
1272
1273 self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1275 self.unify_expression_generic_error(args[1].value, &port_type, ctx)?;
1276
1277 Ok(())
1278 }
1279
1280 pub fn handle_read_memory(
1281 &mut self,
1282 num_elements: TypeVarID,
1283 addr_size_arg: TypeVarID,
1284 args: &[Argument<Expression, TypeSpec>],
1285 ctx: &Context,
1286 ) -> Result<()> {
1287 let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
1288
1289 self.add_constraint(
1290 addr_size.clone(),
1291 bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
1292 args[1].value.loc(),
1293 &addr_type,
1294 ConstraintSource::MemoryIndexing,
1295 );
1296
1297 self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1299
1300 Ok(())
1301 }
1302
1303 #[tracing::instrument(level = "trace", skip(self, turbofish, where_clauses))]
1304 pub fn create_generic_list(
1305 &mut self,
1306 source: GenericListSource,
1307 type_params: &[Loc<TypeParam>],
1308 scope_type_params: &[Loc<TypeParam>],
1309 turbofish: Option<TurbofishCtx>,
1310 where_clauses: &[Loc<WhereClause>],
1311 ) -> Result<GenericListToken> {
1312 let turbofish_params = if let Some(turbofish) = turbofish.as_ref() {
1313 if type_params.is_empty() {
1314 return Err(Diagnostic::error(
1315 turbofish.turbofish,
1316 "Turbofish on non-generic function",
1317 )
1318 .primary_label("Turbofish on non-generic function"));
1319 }
1320
1321 let matched_params =
1322 param_util::match_args_with_params(turbofish.turbofish, &type_params, false)?;
1323
1324 matched_params
1328 .iter()
1329 .map(|matched_param| {
1330 let i = type_params
1331 .iter()
1332 .enumerate()
1333 .find_map(|(i, param)| match ¶m.inner {
1334 TypeParam {
1335 ident,
1336 name_id: _,
1337 trait_bounds: _,
1338 meta: _,
1339 } => {
1340 if ident == matched_param.target {
1341 Some(i)
1342 } else {
1343 None
1344 }
1345 }
1346 })
1347 .unwrap();
1348 (i, matched_param)
1349 })
1350 .sorted_by_key(|(i, _)| *i)
1351 .map(|(_, mp)| Some(mp.value))
1352 .collect::<Vec<_>>()
1353 } else {
1354 type_params.iter().map(|_| None).collect::<Vec<_>>()
1355 };
1356
1357 let mut inline_trait_bounds: Vec<Loc<WhereClause>> = vec![];
1358
1359 let scope_type_params = scope_type_params
1360 .iter()
1361 .map(|param| {
1362 let hir::TypeParam {
1363 ident,
1364 name_id,
1365 trait_bounds,
1366 meta,
1367 } = ¶m.inner;
1368 if !trait_bounds.is_empty() {
1369 if let MetaType::Type = meta {
1370 inline_trait_bounds.push(
1371 WhereClause::Type {
1372 target: name_id.clone().at_loc(ident),
1373 traits: trait_bounds.clone(),
1374 }
1375 .at_loc(param),
1376 );
1377 } else {
1378 return Err(Diagnostic::bug(param, "Trait bounds on generic int")
1379 .primary_label("Trait bounds are only allowed on type parameters"));
1380 }
1381 }
1382 Ok((
1383 name_id.clone(),
1384 self.new_generic_with_meta(param.loc(), meta.clone()),
1385 ))
1386 })
1387 .collect::<Result<Vec<_>>>()?;
1388
1389 let new_list = type_params
1390 .iter()
1391 .enumerate()
1392 .map(|(i, param)| {
1393 let hir::TypeParam {
1394 ident,
1395 name_id,
1396 trait_bounds,
1397 meta,
1398 } = ¶m.inner;
1399
1400 let t = self.new_generic_with_meta(param.loc(), meta.clone());
1401
1402 if let Some(tf) = &turbofish_params[i] {
1403 let tf_ctx = turbofish.as_ref().unwrap();
1404 let ty = self.hir_type_expr_to_var(tf, tf_ctx.prev_generic_list)?;
1405 self.unify(&ty, &t, tf_ctx.type_ctx)
1406 .into_default_diagnostic(param, self)?;
1407 }
1408
1409 if !trait_bounds.is_empty() {
1410 if let MetaType::Type = meta {
1411 inline_trait_bounds.push(
1412 WhereClause::Type {
1413 target: name_id.clone().at_loc(ident),
1414 traits: trait_bounds.clone(),
1415 }
1416 .at_loc(param),
1417 );
1418 }
1419 Ok((name_id.clone(), t))
1420 } else {
1421 Ok((name_id.clone(), t))
1422 }
1423 })
1424 .collect::<Result<Vec<_>>>()?
1425 .into_iter()
1426 .chain(scope_type_params.into_iter())
1427 .map(|(name, t)| (name, t.clone()))
1428 .collect::<HashMap<_, _>>();
1429
1430 self.trace_stack.push(TraceStackEntry::NewGenericList(
1431 new_list
1432 .iter()
1433 .map(|(name, var)| (name.clone(), var.debug_resolve(self)))
1434 .collect(),
1435 ));
1436
1437 let token = self.add_mapped_generic_list(source, new_list.clone());
1438
1439 for constraint in where_clauses.iter().chain(inline_trait_bounds.iter()) {
1440 match &constraint.inner {
1441 WhereClause::Type { target, traits } => {
1442 self.visit_trait_bounds(target, traits.as_slice(), &token)?;
1443 }
1444 WhereClause::Int { target, constraint } => {
1445 let int_constraint = self.visit_const_generic(constraint, &token)?;
1446 let tvar = new_list.get(target).ok_or_else(|| {
1447 Diagnostic::error(
1448 target,
1449 format!("{target} is not a generic parameter on this unit"),
1450 )
1451 .primary_label("Not a generic parameter")
1452 })?;
1453
1454 self.add_constraint(
1455 tvar.clone(),
1456 int_constraint,
1457 constraint.loc(),
1458 &tvar,
1459 ConstraintSource::Where,
1460 );
1461 }
1462 }
1463 }
1464
1465 Ok(token)
1466 }
1467
1468 pub fn add_mapped_generic_list(
1470 &mut self,
1471 source: GenericListSource,
1472 mapping: HashMap<NameID, TypeVarID>,
1473 ) -> GenericListToken {
1474 let reference = match source {
1475 GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1476 GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1477 GenericListSource::ImplBlock { target, id } => {
1478 GenericListToken::ImplBlock(target.clone(), id)
1479 }
1480 GenericListSource::Expression(id) => GenericListToken::Expression(id),
1481 };
1482
1483 if self
1484 .generic_lists
1485 .insert(reference.clone(), mapping)
1486 .is_some()
1487 {
1488 panic!("A generic list already existed for {reference:?}");
1489 }
1490 reference
1491 }
1492
1493 pub fn remove_generic_list(&mut self, source: GenericListSource) {
1494 let reference = match source {
1495 GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1496 GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1497 GenericListSource::ImplBlock { target, id } => {
1498 GenericListToken::ImplBlock(target.clone(), id)
1499 }
1500 GenericListSource::Expression(id) => GenericListToken::Expression(id),
1501 };
1502
1503 self.generic_lists.remove(&reference.clone());
1504 }
1505
1506 #[tracing::instrument(level = "trace", skip_all)]
1507 #[trace_typechecker]
1508 pub fn visit_block(
1509 &mut self,
1510 block: &Block,
1511 ctx: &Context,
1512 generic_list: &GenericListToken,
1513 ) -> Result<()> {
1514 for statement in &block.statements {
1515 self.visit_statement(statement, ctx, generic_list);
1516 }
1517 if let Some(result) = &block.result {
1518 self.visit_expression(result, ctx, generic_list);
1519 }
1520 Ok(())
1521 }
1522
1523 #[tracing::instrument(level = "trace", skip_all)]
1524 pub fn visit_impl_blocks(&mut self, item_list: &ItemList) -> TraitImplList {
1525 let mut trait_impls = TraitImplList::new();
1526 for (target, impls) in &item_list.impls.inner {
1527 for (trait_name, impls) in impls {
1528 for (target_args, impls) in impls {
1529 for (trait_args, impl_block) in impls {
1530 let result =
1531 (|| {
1532 let generic_list = self.create_generic_list(
1533 GenericListSource::ImplBlock {
1534 target,
1535 id: impl_block.id,
1536 },
1537 &[],
1538 impl_block.type_params.as_slice(),
1539 None,
1540 &[],
1541 )?;
1542
1543 let loc = trait_name
1544 .name_loc()
1545 .map(|n| ().at_loc(&n))
1546 .unwrap_or(().at_loc(&impl_block));
1547
1548 let trait_type_params = trait_args
1549 .iter()
1550 .map(|param| {
1551 Ok(TemplateTypeVarID::new(self.hir_type_expr_to_var(
1552 ¶m.clone().at_loc(&loc),
1553 &generic_list,
1554 )?))
1555 })
1556 .collect::<Result<_>>()?;
1557
1558 let target_type_params = target_args
1559 .iter()
1560 .map(|param| {
1561 Ok(TemplateTypeVarID::new(self.hir_type_expr_to_var(
1562 ¶m.clone().at_loc(&loc),
1563 &generic_list,
1564 )?))
1565 })
1566 .collect::<Result<_>>()?;
1567
1568 trait_impls.inner.entry(target.clone()).or_default().push(
1569 TraitImpl {
1570 name: trait_name.clone(),
1571 target_type_params,
1572 trait_type_params,
1573
1574 impl_block: impl_block.inner.clone(),
1575 },
1576 );
1577
1578 Ok(())
1579 })();
1580
1581 match result {
1582 Ok(()) => {}
1583 Err(e) => self.diags.errors.push(e),
1584 }
1585 }
1586 }
1587 }
1588 }
1589
1590 trait_impls
1591 }
1592
1593 #[trace_typechecker]
1594 pub fn visit_pattern(
1595 &mut self,
1596 pattern: &Loc<Pattern>,
1597 ctx: &Context,
1598 generic_list: &GenericListToken,
1599 ) -> Result<()> {
1600 let new_type = self.new_generic_type(pattern.loc());
1601 self.add_equation(TypedExpression::Id(pattern.inner.id), new_type);
1602 match &pattern.inner.kind {
1603 hir::PatternKind::Integer(val) => {
1604 let (num_t, _) = &self.new_generic_number(pattern.loc(), ctx);
1605 self.add_requirement(Requirement::FitsIntLiteral {
1606 value: ConstantInt::Literal(val.clone()),
1607 target_type: num_t.clone().at_loc(pattern),
1608 });
1609 self.unify(pattern, num_t, ctx)
1610 .expect("Failed to unify new_generic with int");
1611 }
1612 hir::PatternKind::Bool(_) => {
1613 pattern
1614 .unify_with(&self.t_bool(pattern.loc(), ctx.symtab), self)
1615 .commit(self, ctx)
1616 .expect("Expected new_generic with boolean");
1617 }
1618 hir::PatternKind::Name { name, pre_declared } => {
1619 if !pre_declared {
1620 self.add_equation(
1621 TypedExpression::Name(name.clone().inner),
1622 pattern.get_type(self),
1623 );
1624 }
1625 self.unify(
1626 &TypedExpression::Id(pattern.id),
1627 &TypedExpression::Name(name.clone().inner),
1628 ctx,
1629 )
1630 .into_default_diagnostic(name.loc(), self)?;
1631 }
1632 hir::PatternKind::Tuple(subpatterns) => {
1633 for pattern in subpatterns {
1634 self.visit_pattern(pattern, ctx, generic_list)?;
1635 }
1636 let tuple_type = self.add_type_var(TypeVar::tuple(
1637 pattern.loc(),
1638 subpatterns
1639 .iter()
1640 .map(|pattern| {
1641 let p_type = pattern.get_type(self);
1642 Ok(p_type)
1643 })
1644 .collect::<Result<_>>()?,
1645 ));
1646
1647 self.unify(pattern, &tuple_type, ctx)
1648 .expect("Unification of new_generic with tuple type cannot fail");
1649 }
1650 hir::PatternKind::Array(inner) => {
1651 for pattern in inner {
1652 self.visit_pattern(pattern, ctx, generic_list)?;
1653 }
1654 if inner.len() == 0 {
1655 return Err(
1656 Diagnostic::error(pattern, "Empty array patterns are unsupported")
1657 .primary_label("Empty array pattern"),
1658 );
1659 } else {
1660 let inner_t = inner[0].get_type(self);
1661
1662 for pattern in inner.iter().skip(1) {
1663 self.unify(pattern, &inner_t, ctx)
1664 .into_default_diagnostic(pattern, self)?;
1665 }
1666
1667 pattern
1668 .unify_with(
1669 &TypeVar::Known(
1670 pattern.loc(),
1671 KnownType::Array,
1672 vec![
1673 inner_t,
1674 self.add_type_var(TypeVar::Known(
1675 pattern.loc(),
1676 KnownType::Integer(inner.len().to_bigint()),
1677 vec![],
1678 )),
1679 ],
1680 )
1681 .insert(self),
1682 self,
1683 )
1684 .commit(self, ctx)
1685 .into_default_diagnostic(pattern, self)?;
1686 }
1687 }
1688 hir::PatternKind::Type(name, args) => {
1689 let (condition_type, params, generic_list) =
1690 match ctx.symtab.patternable_type_by_id(name).inner {
1691 Patternable {
1692 kind: PatternableKind::Enum,
1693 params: _,
1694 } => {
1695 let enum_variant = ctx.symtab.enum_variant_by_id(name).inner;
1696 let generic_list = self.create_generic_list(
1697 GenericListSource::Anonymous,
1698 &enum_variant.type_params,
1699 &[],
1700 None,
1701 &[],
1702 )?;
1703
1704 let condition_type = self.type_var_from_hir(
1705 pattern.loc(),
1706 &enum_variant.output_type,
1707 &generic_list,
1708 )?;
1709
1710 (condition_type, enum_variant.params, generic_list)
1711 }
1712 Patternable {
1713 kind: PatternableKind::Struct,
1714 params: _,
1715 } => {
1716 let s = ctx.symtab.struct_by_id(name).inner;
1717 let generic_list = self.create_generic_list(
1718 GenericListSource::Anonymous,
1719 &s.type_params,
1720 &[],
1721 None,
1722 &[],
1723 )?;
1724
1725 let condition_type =
1726 self.type_var_from_hir(pattern.loc(), &s.self_type, &generic_list)?;
1727
1728 (condition_type, s.params, generic_list)
1729 }
1730 };
1731
1732 self.unify(pattern, &condition_type, ctx)
1733 .expect("Unification of new_generic with enum cannot fail");
1734
1735 for (
1736 PatternArgument {
1737 target,
1738 value: pattern,
1739 kind,
1740 },
1741 Parameter {
1742 name: _,
1743 ty: target_type,
1744 no_mangle: _,
1745 field_translator: _,
1746 },
1747 ) in args.iter().zip(params.0.iter())
1748 {
1749 self.visit_pattern(pattern, ctx, &generic_list)?;
1750 let target_type =
1751 self.type_var_from_hir(target_type.loc(), target_type, &generic_list)?;
1752
1753 let loc = match kind {
1754 hir::ArgumentKind::Positional => pattern.loc(),
1755 hir::ArgumentKind::Named => pattern.loc(),
1756 hir::ArgumentKind::ShortNamed => target.loc(),
1757 };
1758
1759 self.unify(pattern, &target_type, ctx).into_diagnostic(
1760 loc,
1761 |d, tm| {
1762 let (expected, got) = tm.display_e_g(self);
1763 d.message(format!(
1764 "Argument type mismatch. Expected {expected} got {got}"
1765 ))
1766 .primary_label(format!("expected {expected}"))
1767 },
1768 self,
1769 )?;
1770 }
1771 }
1772 }
1773 Ok(())
1774 }
1775
1776 #[trace_typechecker]
1777 pub fn visit_wal_trace(
1778 &mut self,
1779 trace: &Loc<WalTrace>,
1780 ctx: &Context,
1781 generic_list: &GenericListToken,
1782 ) -> Result<()> {
1783 let WalTrace { clk, rst } = &trace.inner;
1784 clk.as_ref()
1785 .map(|x| {
1786 self.visit_expression(x, ctx, generic_list);
1787 x.unify_with(&self.t_clock(trace.loc(), ctx.symtab), self)
1788 .commit(self, ctx)
1789 .into_default_diagnostic(x, self)
1790 })
1791 .transpose()?;
1792 rst.as_ref()
1793 .map(|x| {
1794 self.visit_expression(x, ctx, generic_list);
1795 x.unify_with(&self.t_bool(trace.loc(), ctx.symtab), self)
1796 .commit(self, ctx)
1797 .into_default_diagnostic(x, self)
1798 })
1799 .transpose()?;
1800 Ok(())
1801 }
1802
/// Type-checks a single statement.
///
/// Many sub-errors are pushed into `self.diags` with `handle_in` so that checking can
/// continue. Errors propagated with `?` abort the statement and are returned to the
/// caller.
#[trace_typechecker]
pub fn visit_statement_error(
    &mut self,
    stmt: &Loc<Statement>,
    ctx: &Context,
    generic_list: &GenericListToken,
) -> Result<()> {
    match &stmt.inner {
        Statement::Error => {
            // Inside a pipeline, an erroneous statement turns the current stage depth
            // into the error type (presumably to avoid follow-up errors about it).
            if let Some(current_stage_depth) =
                self.pipeline_state.as_ref().map(|s| s.current_stage_depth)
            {
                current_stage_depth
                    .unify_with(&self.t_err(stmt.loc()), self)
                    .commit(self, ctx)
                    .unwrap();
            }
            Ok(())
        }
        Statement::Binding(Binding {
            pattern,
            ty,
            value,
            wal_trace,
        }) => {
            trace!("Visiting `let {} = ..`", pattern.kind);
            self.visit_expression(value, ctx, generic_list);

            self.visit_pattern(pattern, ctx, generic_list)
                .handle_in(&mut self.diags);

            // The pattern takes the value's type. The mismatch helper is given the
            // annotation's location if there is one, otherwise the value's.
            self.unify(&TypedExpression::Id(pattern.id), value, ctx)
                .into_diagnostic(
                    pattern.loc(),
                    error_pattern_type_mismatch(
                        ty.as_ref().map(|t| t.loc()).unwrap_or_else(|| value.loc()),
                        self,
                    ),
                    self,
                )
                .handle_in(&mut self.diags);

            // An explicit annotation also constrains the pattern.
            if let Some(t) = ty {
                let tvar = self.type_var_from_hir(t.loc(), t, generic_list)?;
                self.unify(&TypedExpression::Id(pattern.id), &tvar, ctx)
                    .into_default_diagnostic(value.loc(), self)
                    .handle_in(&mut self.diags);
            }

            wal_trace
                .as_ref()
                .map(|wt| self.visit_wal_trace(wt, ctx, generic_list))
                .transpose()
                .handle_in(&mut self.diags);

            Ok(())
        }
        Statement::Expression(expr) => {
            self.visit_expression(expr, ctx, generic_list);
            Ok(())
        }
        Statement::Register(reg) => self.visit_register(reg, ctx, generic_list),
        Statement::Declaration(names) => {
            // Each declared name gets a fresh, unconstrained type; the later binding
            // is presumably a `pre_declared` name pattern that reuses it.
            for name in names {
                let new_type = self.new_generic_type(name.loc());
                self.add_equation(TypedExpression::Name(name.clone().inner), new_type);
            }
            Ok(())
        }
        Statement::PipelineRegMarker(extra) => {
            // A conditional marker's condition must be a bool.
            match extra {
                Some(PipelineRegMarkerExtra::Condition(cond)) => {
                    self.visit_expression(cond, ctx, generic_list);
                    cond.unify_with(&self.t_bool(cond.loc(), ctx.symtab), self)
                        .commit(self, ctx)
                        .into_default_diagnostic(cond, self)?;
                }
                Some(PipelineRegMarkerExtra::Count {
                    count: _,
                    count_typeexpr_id: _,
                }) => {}
                None => {}
            }

            let current_stage_depth = self
                .pipeline_state
                .clone()
                .ok_or_else(|| {
                    diag_anyhow!(stmt, "Found a pipeline reg marker in a non-pipeline")
                })?
                .current_stage_depth;

            // The marker advances the stage depth by `count` (1 for plain and
            // conditional markers): new_depth = offset + current_stage_depth.
            let new_depth = self.new_generic_tlint(stmt.loc());
            let offset = match extra {
                Some(PipelineRegMarkerExtra::Count {
                    count,
                    count_typeexpr_id,
                }) => {
                    let var = self.hir_type_expr_to_var(count, generic_list)?;
                    self.add_equation(TypedExpression::Id(*count_typeexpr_id), var.clone());
                    var
                }
                Some(PipelineRegMarkerExtra::Condition(_)) | None => self.add_type_var(
                    TypeVar::Known(stmt.loc(), KnownType::Integer(1.to_bigint()), vec![]),
                ),
            };

            let total_depth = ConstraintExpr::Sum(
                Box::new(ConstraintExpr::Var(offset)),
                Box::new(ConstraintExpr::Var(current_stage_depth)),
            );
            self.pipeline_state
                .as_mut()
                .expect("Expected to have a pipeline state")
                .current_stage_depth = new_depth.clone();

            self.add_constraint(
                new_depth.clone(),
                total_depth,
                stmt.loc(),
                &new_depth,
                ConstraintSource::PipelineRegCount {
                    reg: stmt.loc(),
                    total: self.get_pipeline_state(stmt)?.total_depth.loc(),
                },
            );

            Ok(())
        }
        Statement::Label(name) => {
            // Reuse the label's type variable if an equation already exists
            // (presumably from an earlier reference to it); otherwise create one.
            // Either way the label is tied to the current stage depth.
            let key = TypedExpression::Name(name.inner.clone());
            let var = if !self.equations.contains_key(&key) {
                let var = self.new_generic_tlint(name.loc());
                self.trace_stack.push(TraceStackEntry::AddingPipelineLabel(
                    name.inner.clone(),
                    var.debug_resolve(self),
                ));
                self.add_equation(key.clone(), var.clone());
                var
            } else {
                let var = self.equations.get(&key).unwrap().clone();
                self.trace_stack
                    .push(TraceStackEntry::RecoveringPipelineLabel(
                        name.inner.clone(),
                        var.debug_resolve(self),
                    ));
                var
            };
            self.unify(
                &var,
                &self.get_pipeline_state(name)?.current_stage_depth.clone(),
                ctx,
            )
            .unwrap();
            Ok(())
        }
        Statement::WalSuffixed { .. } => Ok(()),
        Statement::Assert(expr) => {
            self.visit_expression(expr, ctx, generic_list);

            // The asserted expression must be a bool.
            expr.unify_with(&self.t_bool(stmt.loc(), ctx.symtab), self)
                .commit(self, ctx)
                .into_default_diagnostic(expr, self)
                .handle_in(&mut self.diags);
            Ok(())
        }
        Statement::Set { target, value } => {
            self.visit_expression(target, ctx, generic_list);
            self.visit_expression(value, ctx, generic_list);

            // The target must have the inverted form (`TypeVar::inverted`) of the
            // value's type.
            let inner_type = self.new_generic_type(value.loc());
            let outer_type = TypeVar::inverted(stmt.loc(), inner_type.clone()).insert(self);
            self.unify_expression_generic_error(target, &outer_type, ctx)
                .handle_in(&mut self.diags);
            self.unify_expression_generic_error(value, &inner_type, ctx)
                .handle_in(&mut self.diags);

            Ok(())
        }
    }
}
1985
1986 pub fn visit_statement(
1987 &mut self,
1988 stmt: &Loc<Statement>,
1989 ctx: &Context,
1990 generic_list: &GenericListToken,
1991 ) {
1992 if let Err(e) = self.visit_statement_error(stmt, ctx, generic_list) {
1993 self.diags.errors.push(e);
1994 }
1995 }
1996
1997 #[trace_typechecker]
1998 pub fn visit_register(
1999 &mut self,
2000 reg: &Register,
2001 ctx: &Context,
2002 generic_list: &GenericListToken,
2003 ) -> Result<()> {
2004 self.visit_pattern(®.pattern, ctx, generic_list)?;
2005
2006 let type_spec_type = match ®.value_type {
2007 Some(t) => Some(self.type_var_from_hir(t.loc(), t, generic_list)?.at_loc(t)),
2008 None => None,
2009 };
2010
2011 if let Some(tvar) = &type_spec_type {
2014 self.unify(&TypedExpression::Id(reg.pattern.id), tvar, ctx)
2015 .into_diagnostic_no_expected_source(
2016 reg.pattern.loc(),
2017 error_pattern_type_mismatch(tvar.loc(), self),
2018 self,
2019 )?;
2020 }
2021
2022 self.visit_expression(®.clock, ctx, generic_list);
2023 self.visit_expression(®.value, ctx, generic_list);
2024
2025 if let Some(tvar) = &type_spec_type {
2026 self.unify(®.value, tvar, ctx)
2027 .into_default_diagnostic(reg.value.loc(), self)?;
2028 }
2029
2030 if let Some((rst_cond, rst_value)) = ®.reset {
2031 self.visit_expression(rst_cond, ctx, generic_list);
2032 self.visit_expression(rst_value, ctx, generic_list);
2033 rst_cond
2035 .unify_with(&self.t_bool(rst_cond.loc(), ctx.symtab), self)
2036 .commit(self, ctx)
2037 .into_diagnostic(
2038 rst_cond.loc(),
2039 |diag,
2040 Tm {
2041 g: got,
2042 e: _expected,
2043 }| {
2044 diag.message(format!(
2045 "Register reset condition must be a bool, got {got}",
2046 got = got.display(self)
2047 ))
2048 .primary_label("expected bool")
2049 },
2050 self,
2051 )?;
2052
2053 self.unify(&rst_value.inner, ®.value.inner, ctx)
2055 .into_diagnostic(
2056 rst_value.loc(),
2057 |diag, tm| {
2058 let (expected, got) = tm.display_e_g(self);
2059 diag.message(format!(
2060 "Register reset value mismatch. Expected {expected} got {got}"
2061 ))
2062 .primary_label(format!("expected {expected}"))
2063 .secondary_label(®.pattern, format!("because this has type {expected}"))
2064 },
2065 self,
2066 )?;
2067 }
2068
2069 if let Some(initial) = ®.initial {
2070 self.visit_expression(initial, ctx, generic_list);
2071
2072 self.unify(&initial.inner, ®.value.inner, ctx)
2073 .into_diagnostic(
2074 initial.loc(),
2075 |diag, tm| {
2076 let (expected, got) = tm.display_e_g(self);
2077 diag.message(format!(
2078 "Register initial value mismatch. Expected {expected} got {got}"
2079 ))
2080 .primary_label(format!("expected {expected}, got {got}"))
2081 .secondary_label(®.pattern, format!("because this has type {got}"))
2082 },
2083 self,
2084 )?;
2085 }
2086
2087 reg.clock
2088 .unify_with(&self.t_clock(reg.clock.loc(), ctx.symtab), self)
2089 .commit(self, ctx)
2090 .into_diagnostic(
2091 reg.clock.loc(),
2092 |diag,
2093 Tm {
2094 g: got,
2095 e: _expected,
2096 }| {
2097 diag.message(format!(
2098 "Expected clock, got {got}",
2099 got = got.display(self)
2100 ))
2101 .primary_label("expected clock")
2102 },
2103 self,
2104 )?;
2105
2106 self.unify(&TypedExpression::Id(reg.pattern.id), ®.value, ctx)
2107 .into_diagnostic(
2108 reg.pattern.loc(),
2109 error_pattern_type_mismatch(reg.value.loc(), self),
2110 self,
2111 )?;
2112
2113 Ok(())
2114 }
2115
2116 #[trace_typechecker]
2117 pub fn visit_trait_spec(
2118 &mut self,
2119 trait_spec: &Loc<TraitSpec>,
2120 generic_list: &GenericListToken,
2121 ) -> Result<Loc<TraitReq>> {
2122 let type_params = if let Some(type_params) = &trait_spec.inner.type_params {
2123 type_params
2124 .inner
2125 .iter()
2126 .map(|te| self.hir_type_expr_to_var(te, generic_list))
2127 .collect::<Result<_>>()?
2128 } else {
2129 vec![]
2130 };
2131
2132 Ok(TraitReq {
2133 name: trait_spec.name.clone(),
2134 type_params,
2135 }
2136 .at_loc(trait_spec))
2137 }
2138
2139 #[trace_typechecker]
2140 pub fn visit_trait_bounds(
2141 &mut self,
2142 target: &Loc<NameID>,
2143 traits: &[Loc<TraitSpec>],
2144 generic_list_tok: &GenericListToken,
2145 ) -> Result<()> {
2146 let trait_reqs = traits
2147 .iter()
2148 .map(|spec| self.visit_trait_spec(spec, generic_list_tok))
2149 .collect::<Result<BTreeSet<_>>>()?
2150 .into_iter()
2151 .collect_vec();
2152
2153 if !trait_reqs.is_empty() {
2154 let trait_list = TraitList::from_vec(trait_reqs);
2155
2156 let generic_list = self.generic_lists.get(generic_list_tok).unwrap();
2157
2158 let Some(tvar) = generic_list.get(&target.inner) else {
2159 return Err(Diagnostic::bug(
2160 target,
2161 "Couldn't find generic from where clause in generic list",
2162 )
2163 .primary_label(format!(
2164 "Generic type {} not found in generic list",
2165 target.inner
2166 )));
2167 };
2168
2169 self.trace_stack.push(TraceStackEntry::AddingTraitBounds(
2170 tvar.debug_resolve(self),
2171 trait_list.clone(),
2172 ));
2173
2174 match tvar.resolve(self) {
2175 TypeVar::Known(_, _, _) => {
2176 }
2182 TypeVar::Unknown(loc, id, old_trait_list, _meta_type) => {
2183 let new_tvar = self.add_type_var(TypeVar::Unknown(
2184 *loc,
2185 *id,
2186 old_trait_list.clone().extend(trait_list),
2187 MetaType::Type,
2188 ));
2189
2190 trace!(
2191 "Adding trait bound {} on type {}",
2192 new_tvar.display_with_meta(true, self),
2193 target.inner
2194 );
2195
2196 let generic_list = self.generic_lists.get_mut(generic_list_tok).unwrap();
2197 generic_list.insert(target.inner.clone(), new_tvar);
2198 }
2199 }
2200 }
2201
2202 Ok(())
2203 }
2204
2205 pub fn visit_const_generic_with_id(
2206 &mut self,
2207 gen: &Loc<ConstGenericWithId>,
2208 generic_list_token: &GenericListToken,
2209 constraint_source: ConstraintSource,
2210 ctx: &Context,
2211 ) -> Result<TypeVarID> {
2212 let var = match &gen.inner.inner {
2213 ConstGeneric::Name(name) => {
2214 let ty = &ctx.symtab.type_symbol_by_id(&name);
2215 match &ty.inner {
2216 TypeSymbol::Declared(_, _) => {
2217 return Err(Diagnostic::error(
2218 name,
2219 "{type_decl_kind} cannot be used in a const generic expression",
2220 )
2221 .primary_label("Type in const generic")
2222 .secondary_label(ty, "{name} is declared here"))
2223 }
2224 TypeSymbol::GenericArg { .. } | TypeSymbol::GenericMeta(MetaType::Type) => {
2225 return Err(Diagnostic::error(
2226 name,
2227 "Generic types cannot be used in const generic expressions",
2228 )
2229 .primary_label("Type in const generic")
2230 .secondary_label(ty, "{name} is declared here")
2231 .span_suggest_insert_before(
2232 "Consider making this a value",
2233 ty.loc(),
2234 "#uint ",
2235 ))
2236 }
2237 TypeSymbol::GenericMeta(MetaType::Number) => {
2238 self.new_generic_tlnumber(gen.loc())
2239 }
2240 TypeSymbol::GenericMeta(MetaType::Int) => self.new_generic_tlint(gen.loc()),
2241 TypeSymbol::GenericMeta(MetaType::Uint) => self.new_generic_tluint(gen.loc()),
2242 TypeSymbol::GenericMeta(MetaType::Bool) => self.new_generic_tlbool(gen.loc()),
2243 TypeSymbol::GenericMeta(MetaType::Str) => self.new_generic_tlstr(gen.loc()),
2244 TypeSymbol::GenericMeta(MetaType::Any) => {
2245 diag_bail!(gen, "Found any meta type")
2246 }
2247 TypeSymbol::Alias(_) => {
2248 return Err(Diagnostic::error(
2249 gen,
2250 "Aliases are not currently supported in const generics",
2251 )
2252 .secondary_label(ty, "Alias defined here"))
2253 }
2254 }
2255 }
2256 ConstGeneric::Int(_)
2257 | ConstGeneric::Add(_, _)
2258 | ConstGeneric::Sub(_, _)
2259 | ConstGeneric::Mul(_, _)
2260 | ConstGeneric::Div(_, _)
2261 | ConstGeneric::Mod(_, _)
2262 | ConstGeneric::UintBitsToFit(_) => self.new_generic_tlnumber(gen.loc()),
2263 ConstGeneric::Str(_) => self.new_generic_tlstr(gen.loc()),
2264 ConstGeneric::Eq(_, _) | ConstGeneric::NotEq(_, _) => {
2265 self.new_generic_tlbool(gen.loc())
2266 }
2267 };
2268 let constraint = self.visit_const_generic(&gen.inner.inner, generic_list_token)?;
2269 self.add_equation(TypedExpression::Id(gen.id), var.clone());
2270 self.add_constraint(var.clone(), constraint, gen.loc(), &var, constraint_source);
2271 Ok(var)
2272 }
2273
2274 #[trace_typechecker]
2275 pub fn visit_const_generic(
2276 &self,
2277 constraint: &ConstGeneric,
2278 generic_list: &GenericListToken,
2279 ) -> Result<ConstraintExpr> {
2280 let wrap = |lhs,
2281 rhs,
2282 wrapper: fn(Box<ConstraintExpr>, Box<ConstraintExpr>) -> ConstraintExpr|
2283 -> Result<_> {
2284 Ok(wrapper(
2285 Box::new(self.visit_const_generic(lhs, generic_list)?),
2286 Box::new(self.visit_const_generic(rhs, generic_list)?),
2287 ))
2288 };
2289 let constraint = match constraint {
2290 ConstGeneric::Name(n) => {
2291 let var = self
2292 .get_generic_list(generic_list)
2293 .ok_or_else(|| diag_anyhow!(n, "Found no generic list"))?
2294 .get(n)
2295 .ok_or_else(|| {
2296 Diagnostic::bug(n, "Found non-generic argument in where clause")
2297 })?;
2298 ConstraintExpr::Var(*var)
2299 }
2300 ConstGeneric::Int(val) => ConstraintExpr::Integer(val.clone()),
2301 ConstGeneric::Str(val) => ConstraintExpr::String(val.clone()),
2302 ConstGeneric::Add(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Sum)?,
2303 ConstGeneric::Sub(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Difference)?,
2304 ConstGeneric::Mul(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Product)?,
2305 ConstGeneric::Div(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Div)?,
2306 ConstGeneric::Mod(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Mod)?,
2307 ConstGeneric::Eq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Eq)?,
2308 ConstGeneric::NotEq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::NotEq)?,
2309 ConstGeneric::UintBitsToFit(a) => ConstraintExpr::UintBitsToRepresent(Box::new(
2310 self.visit_const_generic(a, generic_list)?,
2311 )),
2312 };
2313 Ok(constraint)
2314 }
2315}
2316
2317impl TypeState {
2319 fn new_typeid(&self) -> u64 {
2320 let mut next = self.next_typeid.borrow_mut();
2321 let result = *next;
2322 *next += 1;
2323 result
2324 }
2325
2326 pub fn add_equation(&mut self, expression: TypedExpression, var: TypeVarID) {
2327 self.trace_stack.push(TraceStackEntry::AddingEquation(
2328 expression.clone(),
2329 var.debug_resolve(self),
2330 ));
2331 if let Some(prev) = self.equations.insert(expression.clone(), var.clone()) {
2332 let var = var.clone();
2333 let expr = expression.clone();
2334 println!("{}", format_trace_stack(self));
2335 panic!("Adding equation for {} == {} but a previous eq exists.\n\tIt was previously bound to {}", expr, var.debug_resolve(self), prev.debug_resolve(self))
2336 }
2337 }
2338
2339 fn add_constraint(
2340 &mut self,
2341 lhs: TypeVarID,
2342 rhs: ConstraintExpr,
2343 loc: Loc<()>,
2344 inside: &TypeVarID,
2345 source: ConstraintSource,
2346 ) {
2347 let replaces = lhs.clone();
2348 let rhs = rhs.with_context(&replaces, &inside, source).at_loc(&loc);
2349
2350 self.trace_stack.push(TraceStackEntry::AddingConstraint(
2351 lhs.debug_resolve(self),
2352 rhs.inner.clone(),
2353 ));
2354
2355 self.constraints.add_int_constraint(lhs, rhs);
2356 }
2357
2358 fn add_requirement(&mut self, requirement: Requirement) {
2359 self.trace_stack
2360 .push(TraceStackEntry::AddRequirement(requirement.clone()));
2361 self.requirements.push(requirement)
2362 }
2363
2364 fn unify_inner(
2369 &mut self,
2370 e1: &impl HasType,
2371 e2: &impl HasType,
2372 ctx: &Context,
2373 ) -> std::result::Result<TypeVarID, UnificationError> {
2374 let v1 = e1.get_type(self);
2375 let v2 = e2.get_type(self);
2376
2377 trace!(
2378 "Unifying {} with {}",
2379 v1.debug_resolve(self),
2380 v2.debug_resolve(self)
2381 );
2382
2383 self.trace_stack.push(TraceStackEntry::TryingUnify(
2384 v1.debug_resolve(self),
2385 v2.debug_resolve(self),
2386 ));
2387
2388 macro_rules! err_producer {
2389 () => {{
2390 self.trace_stack
2391 .push(TraceStackEntry::Message("Produced error".to_string()));
2392 UnificationError::Normal(Tm {
2393 g: UnificationTrace::new(v1),
2394 e: UnificationTrace::new(v2),
2395 })
2396 }};
2397 }
2398 macro_rules! meta_err_producer {
2399 () => {{
2400 self.trace_stack
2401 .push(TraceStackEntry::Message("Produced error".to_string()));
2402 UnificationError::MetaMismatch(Tm {
2403 g: UnificationTrace::new(v1),
2404 e: UnificationTrace::new(v2),
2405 })
2406 }};
2407 }
2408
2409 macro_rules! unify_if {
2410 ($condition:expr, $new_type:expr, $replaced_type:expr) => {
2411 if $condition {
2412 Ok(($new_type, $replaced_type))
2413 } else {
2414 Err(err_producer!())
2415 }
2416 };
2417 }
2418
2419 let unify_params = |s: &mut Self,
2420 p1: &[TypeVarID],
2421 p2: &[TypeVarID]|
2422 -> std::result::Result<(), UnificationError> {
2423 if p1.len() != p2.len() {
2424 return Err({
2425 s.trace_stack
2426 .push(TraceStackEntry::Message("Produced error".to_string()));
2427 UnificationError::Normal(Tm {
2428 g: UnificationTrace::new(v1),
2429 e: UnificationTrace::new(v2),
2430 })
2431 });
2432 }
2433
2434 for (t1, t2) in p1.iter().zip(p2.iter()) {
2435 match s.unify_inner(t1, t2, ctx) {
2436 Ok(result) => result,
2437 Err(e) => {
2438 s.trace_stack
2439 .push(TraceStackEntry::Message("Adding context".to_string()));
2440 return Err(e).add_context(v1.clone(), v2.clone());
2441 }
2442 };
2443 }
2444 Ok(())
2445 };
2446
2447 let result = match (
2450 &(v1, v1.resolve(self).clone()),
2451 &(v2, v2.resolve(self).clone()),
2452 ) {
2453 ((_, TypeVar::Known(_, KnownType::Error, _)), _) => Ok((v1, vec![v2])),
2454 (_, (_, TypeVar::Known(_, KnownType::Error, _))) => Ok((v2, vec![v1])),
2455 ((_, TypeVar::Known(_, t1, p1)), (_, TypeVar::Known(_, t2, p2))) => {
2456 match (t1, t2) {
2457 (KnownType::Integer(val1), KnownType::Integer(val2)) => {
2458 unify_if!(val1 == val2, v1, vec![])
2463 }
2464 (KnownType::String(val1), KnownType::String(val2)) => {
2465 unify_if!(val1 == val2, v1, vec![])
2467 }
2468 (KnownType::Named(n1), KnownType::Named(n2)) => {
2469 match (
2470 &ctx.symtab.type_symbol_by_id(n1).inner,
2471 &ctx.symtab.type_symbol_by_id(n2).inner,
2472 ) {
2473 (TypeSymbol::Declared(_, _), TypeSymbol::Declared(_, _)) => {
2474 if n1 != n2 {
2475 return Err(err_producer!());
2476 }
2477
2478 let new_ts1 = ctx.symtab.type_symbol_by_id(n1).inner;
2479 let new_ts2 = ctx.symtab.type_symbol_by_id(n2).inner;
2480 unify_params(self, &p1, &p2)?;
2481 unify_if!(new_ts1 == new_ts2, v1, vec![])
2482 }
2483 (TypeSymbol::Declared(_, _), TypeSymbol::GenericArg { traits }) => {
2484 if !traits.is_empty() {
2485 todo!("Implement trait unifictaion");
2486 }
2487 Ok((v1, vec![]))
2488 }
2489 (TypeSymbol::GenericArg { traits }, TypeSymbol::Declared(_, _)) => {
2490 if !traits.is_empty() {
2491 todo!("Implement trait unifictaion");
2492 }
2493 Ok((v2, vec![]))
2494 }
2495 (
2496 TypeSymbol::GenericArg { traits: ltraits },
2497 TypeSymbol::GenericArg { traits: rtraits },
2498 ) => {
2499 if !ltraits.is_empty() || !rtraits.is_empty() {
2500 todo!("Implement trait unifictaion");
2501 }
2502 Ok((v1, vec![]))
2503 }
2504 (TypeSymbol::Declared(_, _), TypeSymbol::GenericMeta(_)) => todo!(),
2505 (TypeSymbol::GenericArg { traits: _ }, TypeSymbol::GenericMeta(_)) => {
2506 todo!()
2507 }
2508 (TypeSymbol::GenericMeta(_), TypeSymbol::Declared(_, _)) => todo!(),
2509 (TypeSymbol::GenericMeta(_), TypeSymbol::GenericArg { traits: _ }) => {
2510 todo!()
2511 }
2512 (TypeSymbol::Alias(_), _) | (_, TypeSymbol::Alias(_)) => {
2513 return Err(UnificationError::Specific(Diagnostic::bug(
2514 ().nowhere(),
2515 "Encountered a raw type alias during unification",
2516 )))
2517 }
2518 (TypeSymbol::GenericMeta(_), TypeSymbol::GenericMeta(_)) => todo!(),
2519 }
2520 }
2521 (KnownType::Array, KnownType::Array)
2522 | (KnownType::Tuple, KnownType::Tuple)
2523 | (KnownType::Wire, KnownType::Wire)
2524 | (KnownType::Inverted, KnownType::Inverted) => {
2525 unify_params(self, &p1, &p2)?;
2529 Ok((v1, vec![]))
2530 }
2531 (_, _) => Err(err_producer!()),
2532 }
2533 }
2534 (
2536 (_, TypeVar::Unknown(loc1, _, traits1, meta1)),
2537 (_, TypeVar::Unknown(loc2, _, traits2, meta2)),
2538 ) => {
2539 let new_loc = if meta1.is_more_concrete_than(meta2) {
2540 loc1
2541 } else {
2542 loc2
2543 };
2544 let new_t = match unify_meta(meta1, meta2) {
2545 Some(meta @ MetaType::Any) | Some(meta @ MetaType::Number) => {
2546 if traits1.inner.is_empty() || traits2.inner.is_empty() {
2547 return Err(UnificationError::Specific(diag_anyhow!(
2548 new_loc,
2549 "Inferred an any meta-type with traits"
2550 )));
2551 }
2552 self.new_generic_with_meta(*loc1, meta)
2553 }
2554 Some(MetaType::Type) => {
2555 let new_trait_names = traits1
2556 .inner
2557 .iter()
2558 .chain(traits2.inner.iter())
2559 .map(|t| t.name.clone())
2560 .collect::<BTreeSet<_>>()
2561 .into_iter()
2562 .collect::<Vec<_>>();
2563
2564 let new_traits = new_trait_names
2565 .iter()
2566 .map(
2567 |name| match (traits1.get_trait(name), traits2.get_trait(name)) {
2568 (Some(req1), Some(req2)) => {
2569 let new_params = req1
2570 .inner
2571 .type_params
2572 .iter()
2573 .zip(req2.inner.type_params.iter())
2574 .map(|(p1, p2)| self.unify(p1, p2, ctx))
2575 .collect::<std::result::Result<_, UnificationError>>(
2576 )?;
2577
2578 Ok(TraitReq {
2579 name: name.clone(),
2580 type_params: new_params,
2581 }
2582 .at_loc(req1))
2583 }
2584 (Some(t), None) => Ok(t.clone()),
2585 (None, Some(t)) => Ok(t.clone()),
2586 (None, None) => panic!("Found a trait but neither side has it"),
2587 },
2588 )
2589 .collect::<std::result::Result<Vec<_>, UnificationError>>()?;
2590
2591 self.new_generic_with_traits(*new_loc, TraitList::from_vec(new_traits))
2592 }
2593 Some(MetaType::Int) => self.new_generic_tlint(*new_loc),
2594 Some(MetaType::Uint) => self.new_generic_tluint(*new_loc),
2595 Some(MetaType::Bool) => self.new_generic_tlbool(*new_loc),
2596 Some(MetaType::Str) => self.new_generic_tlstr(*new_loc),
2597 None => return Err(meta_err_producer!()),
2598 };
2599 Ok((new_t, vec![v1, v2]))
2600 }
2601 (
2602 (otherid, TypeVar::Known(loc, base, params)),
2603 (ukid, TypeVar::Unknown(ukloc, _, traits, meta)),
2604 )
2605 | (
2606 (ukid, TypeVar::Unknown(ukloc, _, traits, meta)),
2607 (otherid, TypeVar::Known(loc, base, params)),
2608 ) => {
2609 let trait_is_expected = match (&v1.resolve(self), &v2.resolve(self)) {
2610 (TypeVar::Known(_, _, _), _) => true,
2611 _ => false,
2612 };
2613
2614 let impls = self.ensure_impls(otherid, traits, trait_is_expected, ukloc, ctx)?;
2615
2616 self.trace_stack.push(TraceStackEntry::Message(
2617 "Unifying trait_parameters".to_string(),
2618 ));
2619 let mut new_params = params.clone();
2620 for (trait_impl, trait_req) in impls {
2621 let mut param_map = BTreeMap::new();
2622
2623 for (l, r) in trait_req
2624 .type_params
2625 .iter()
2626 .zip(trait_impl.trait_type_params)
2627 {
2628 let copy = r.make_copy_with_mapping(self, &mut param_map);
2629 self.unify(l, ©, ctx)?;
2630 }
2631
2632 new_params = trait_impl
2633 .target_type_params
2634 .iter()
2635 .zip(new_params)
2636 .map(|(l, r)| {
2637 let copy = l.make_copy_with_mapping(self, &mut param_map);
2638 self.unify(©, &r, ctx).add_context(*ukid, *otherid)
2639 })
2640 .collect::<std::result::Result<_, _>>()?
2641 }
2642
2643 match (base, meta) {
2644 (KnownType::Error, _) => {
2645 unreachable!()
2646 }
2647 (_, MetaType::Any)
2649 | (KnownType::Named(_), MetaType::Type)
2651 | (KnownType::Tuple, MetaType::Type)
2652 | (KnownType::Array, MetaType::Type)
2653 | (KnownType::Wire, MetaType::Type)
2654 | (KnownType::Bool(_), MetaType::Bool)
2655 | (KnownType::String(_), MetaType::Str)
2656 | (KnownType::Inverted, MetaType::Type)
2657 | (KnownType::Integer(_), MetaType::Int)
2659 | (KnownType::Integer(_), MetaType::Number)
2660 => {
2661 let new = self.add_type_var(TypeVar::Known(*loc, base.clone(), new_params));
2662
2663 Ok((new, vec![otherid.clone(), ukid.clone()]))
2664 },
2665 (KnownType::Integer(val), MetaType::Uint)
2667 => {
2668 if val < &0.to_bigint() {
2669 Err(meta_err_producer!())
2670 } else {
2671 let new = self.add_type_var(TypeVar::Known(*loc, base.clone(), new_params));
2672
2673 Ok((new, vec![otherid.clone(), ukid.clone()]))
2674 }
2675 },
2676
2677 (KnownType::Integer(_), MetaType::Type) => Err(meta_err_producer!()),
2679
2680 (_, MetaType::Bool) => Err(meta_err_producer!()),
2682 (KnownType::Bool(_), _) => Err(meta_err_producer!()),
2683
2684 (_, MetaType::Str) => Err(meta_err_producer!()),
2686 (KnownType::String(_), _) => Err(meta_err_producer!()),
2687
2688 (KnownType::Named(_), MetaType::Int | MetaType::Number | MetaType::Uint)
2690 | (KnownType::Tuple, MetaType::Int | MetaType::Uint | MetaType::Number)
2691 | (KnownType::Array, MetaType::Int | MetaType::Uint | MetaType::Number)
2692 | (KnownType::Wire, MetaType::Int | MetaType::Uint | MetaType::Number)
2693 | (KnownType::Inverted, MetaType::Int | MetaType::Uint | MetaType::Number)
2694 => Err(meta_err_producer!())
2695 }
2696 }
2697 };
2698
2699 let (new_type, replaced_types) = result?;
2700
2701 self.trace_stack.push(TraceStackEntry::Unified(
2702 v1.debug_resolve(self),
2703 v2.debug_resolve(self),
2704 new_type.debug_resolve(self),
2705 replaced_types
2706 .iter()
2707 .map(|v| v.debug_resolve(self))
2708 .collect(),
2709 ));
2710
2711 for replaced_type in &replaced_types {
2712 if v1.inner != v2.inner {
2713 let (from, to) = (replaced_type.get_type(self), new_type.get_type(self));
2714 self.replacements.insert(from, to);
2715 if let Err(rec) = self.check_type_for_recursion(to, &mut vec![]) {
2716 let err_t = self.t_err(().nowhere());
2717 self.replacements.insert(to, err_t);
2718 return Err(UnificationError::RecursiveType(rec));
2719 }
2720 }
2721 }
2722
2723 Ok(new_type)
2724 }
2725
2726 pub fn can_unify(&mut self, e1: &impl HasType, e2: &impl HasType, ctx: &Context) -> bool {
2727 self.trace_stack
2728 .push(TraceStackEntry::Enter("Running can_unify".to_string()));
2729 let result = self.do_and_restore(|s| s.unify(e1, e2, ctx)).is_ok();
2730 self.trace_stack.push(TraceStackEntry::Exit);
2731 result
2732 }
2733
    /// Unifies the types of `e1` and `e2`, then propagates type-level value
    /// constraints until they stop producing new information.
    ///
    /// Each propagation round updates `self.constraints`. Every inferred value is
    /// unified (via `unify_inner`) with the variable it constrains. A mismatch found
    /// this way is reported as `UnificationError::FromConstraints`, with the
    /// constraint's context spliced into the got/expected traces. All other errors
    /// are passed through unchanged.
    ///
    /// The returned id is the one from the initial unification. Later
    /// constraint-driven unifications may have replaced it, so look it up through
    /// the replacements (e.g. with `HasType::get_type`) before inspecting it.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn unify(
        &mut self,
        e1: &impl HasType,
        e2: &impl HasType,
        ctx: &Context,
    ) -> std::result::Result<TypeVarID, UnificationError> {
        let new_type = self.unify_inner(e1, e2, ctx)?;

        // Fixpoint: keep going until a constraint update yields nothing new.
        loop {
            trace!("Updating constraints");
            let new_info;
            // Destructuring assignment: the updated constraint set replaces the old one.
            (self.constraints, new_info) = self
                .constraints
                .clone()
                .update_type_level_value_constraints(self);

            if new_info.is_empty() {
                break;
            }

            for constraint in new_info {
                trace!(
                    "Constraint replaces {} with {:?}",
                    constraint.inner.0.display(self),
                    constraint.inner.1
                );

                let ((var, replacement), loc) = constraint.split_loc();

                self.trace_stack
                    .push(TraceStackEntry::InferringFromConstraints(
                        var.debug_resolve(self),
                        replacement.val.clone(),
                    ));

                // Materialise the inferred value as a known type and unify it with `var`.
                let expected_type = self.add_type_var(TypeVar::Known(loc, replacement.val, vec![]));
                let result = self.unify_inner(&expected_type.clone().at_loc(&loc), &var, ctx);
                let is_meta_error = matches!(result, Err(UnificationError::MetaMismatch { .. }));
                match result {
                    Ok(_) => {}
                    Err(UnificationError::Normal(Tm { mut e, mut g }))
                    | Err(UnificationError::MetaMismatch(Tm { mut e, mut g })) => {
                        // Show the failure inside the type the constraint came from.
                        e.inside.replace(
                            replacement
                                .context
                                .inside
                                .replace_inside(var, e.failing, self),
                        );
                        g.inside.replace(
                            replacement
                                .context
                                .inside
                                .replace_inside(var, g.failing, self),
                        );
                        return Err(UnificationError::FromConstraints {
                            got: g,
                            expected: e,
                            source: replacement.context.source,
                            loc,
                            is_meta_error,
                        });
                    }
                    Err(
                        e @ UnificationError::FromConstraints { .. }
                        | e @ UnificationError::Specific { .. }
                        | e @ UnificationError::RecursiveType(_)
                        | e @ UnificationError::UnsatisfiedTraits { .. },
                    ) => return Err(e),
                };
            }
        }

        Ok(new_type)
    }
2813
    /// Checks whether `ty` (as currently resolved) contains itself, i.e. is an
    /// infinitely recursive type.
    ///
    /// `seen` holds the type variables on the path from the root down to `ty`.
    /// Callers start with an empty vector. On success `ty` is popped off `seen`
    /// again. On failure `seen` is not cleaned up.
    ///
    /// On failure, returns a rendering of the recursive type for diagnostics:
    /// - `*` marks where a known type's parameter points back up the path.
    /// - `...` marks the same for a trait requirement's parameter.
    /// - `_` stands for sibling parameters that are not on the recursive path.
    fn check_type_for_recursion(
        &self,
        ty: TypeVarID,
        seen: &mut Vec<TypeVarID>,
    ) -> std::result::Result<(), String> {
        seen.push(ty);
        match ty.resolve(self) {
            TypeVar::Known(_, base, params) => {
                for (i, param) in params.iter().enumerate() {
                    if seen.contains(param) {
                        return Err("*".to_string());
                    }

                    if let Err(rest) = self.check_type_for_recursion(*param, seen) {
                        // Render this level, with the recursive parameter at position `i`.
                        let list = params
                            .iter()
                            .enumerate()
                            .map(|(j, _)| {
                                if j == i {
                                    rest.clone()
                                } else {
                                    "_".to_string()
                                }
                            })
                            .join(", ");

                        match base {
                            // Recursion below the error type is not reported; the
                            // remaining parameters are still checked.
                            KnownType::Error => {}
                            KnownType::Named(name_id) => {
                                return Err(format!("{name_id}<{}>", list));
                            }
                            KnownType::Bool(_) | KnownType::Integer(_) | KnownType::String(_) => {
                                unreachable!("Encountered recursive type level bool, int or str")
                            }
                            KnownType::Tuple => return Err(format!("({})", list)),
                            KnownType::Array => return Err(format!("[{}]", list)),
                            KnownType::Wire => return Err(format!("&{}", list)),
                            KnownType::Inverted => return Err(format!("inv {}", list)),
                        }
                    }
                }
            }
            TypeVar::Unknown(_, _, traits, _) => {
                // An unknown type can still recurse through its trait requirements' parameters.
                for t in &traits.inner {
                    for param in &t.type_params {
                        if seen.contains(param) {
                            return Err("...".to_string());
                        }

                        self.check_type_for_recursion(*param, seen)?;
                    }
                }
            }
        }
        seen.pop();
        Ok(())
    }
2871
    /// Checks that the type `var` implements every trait in `traits`. For each
    /// requirement, returns the single matching impl paired with that requirement.
    ///
    /// A candidate impl matches a requirement when all of these hold:
    /// - its trait name equals the requirement's name;
    /// - (copies of) its trait type parameters unify with the requirement's type
    ///   parameters;
    /// - its target type parameters unify with `var`'s parameters.
    ///
    /// These trial unifications run between `checkpoint` and `restore`, so they
    /// leave no trace in the state.
    ///
    /// `trait_is_expected` decides whether the trait-bounded generic built for an
    /// error is reported as the expected side (`true`) or the got side (`false`).
    /// `trait_list_loc` is where that generic, or the unsatisfied-traits error,
    /// is located.
    ///
    /// # Errors
    /// - `UnificationError::Specific` if more than one impl matches a requirement.
    /// - If some requirements have no matching impl:
    ///   - With `trait_is_expected`, `UnsatisfiedTraits` is returned. The exception
    ///     is when the only unsatisfied trait is `Number`, which gives a plain
    ///     mismatch.
    ///   - Without `trait_is_expected`, the error is always a plain mismatch.
    ///
    /// # Panics
    /// Panics if `var` resolves to an unknown type, or if the `Number` trait is not
    /// in the symbol table.
    fn ensure_impls(
        &mut self,
        var: &TypeVarID,
        traits: &TraitList,
        trait_is_expected: bool,
        trait_list_loc: &Loc<()>,
        ctx: &Context,
    ) -> std::result::Result<Vec<(TraitImpl, TraitReq)>, UnificationError> {
        self.trace_stack.push(TraceStackEntry::EnsuringImpls(
            var.debug_resolve(self),
            traits.clone(),
            trait_is_expected,
        ));

        let number = ctx
            .symtab
            .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
            .expect("Did not find number in symtab")
            .0;

        // Builds the error for the traits in `$required_traits` that `var` lacks.
        macro_rules! error_producer {
            ($required_traits:expr) => {
                if trait_is_expected {
                    if $required_traits.inner.len() == 1
                        && $required_traits
                            .get_trait(&TraitName::Named(number.clone().nowhere()))
                            .is_some()
                    {
                        Err(UnificationError::Normal(Tm {
                            e: UnificationTrace::new(
                                self.new_generic_with_traits(*trait_list_loc, $required_traits),
                            ),
                            g: UnificationTrace::new(var.clone()),
                        }))
                    } else {
                        Err(UnificationError::UnsatisfiedTraits {
                            var: *var,
                            traits: $required_traits.inner,
                            target_loc: trait_list_loc.clone(),
                        })
                    }
                } else {
                    Err(UnificationError::Normal(Tm {
                        e: UnificationTrace::new(var.clone()),
                        g: UnificationTrace::new(
                            self.new_generic_with_traits(*trait_list_loc, $required_traits),
                        ),
                    }))
                }
            };
        }

        match &var.resolve(self).clone() {
            TypeVar::Known(_, known, params) if known.into_impl_target().is_some() => {
                let Some(target) = known.into_impl_target() else {
                    unreachable!()
                };

                // Left: (matching impl, requirement); Right: unsatisfied requirement.
                let (impls, unsatisfied): (Vec<_>, Vec<_>) = traits
                    .inner
                    .iter()
                    .map(|trait_req| {
                        if let Some(impld) = self.trait_impls.inner.get(&target).cloned() {
                            let target_impls = impld
                                .iter()
                                .filter_map(|trait_impl| {
                                    // Trial unifications below are undone by `restore`.
                                    self.checkpoint();
                                    let trait_params_match = trait_impl
                                        .trait_type_params
                                        .iter()
                                        .zip(trait_req.type_params.iter())
                                        .all(|(l, r)| {
                                            let l = l.make_copy(self);
                                            self.unify(&l, r, ctx).is_ok()
                                        });

                                    let impl_params_match =
                                        trait_impl.target_type_params.iter().zip(params).all(
                                            |(l, r)| {
                                                let l = l.make_copy(self);
                                                self.unify(&l, r, ctx).is_ok()
                                            },
                                        );
                                    self.restore();

                                    if trait_impl.name == trait_req.name
                                        && trait_params_match
                                        && impl_params_match
                                    {
                                        Some(trait_impl)
                                    } else {
                                        None
                                    }
                                })
                                .collect::<Vec<_>>();

                            if target_impls.len() == 0 {
                                Ok(Either::Right(trait_req.clone()))
                            } else if target_impls.len() == 1 {
                                let target_impl = *target_impls.last().unwrap();
                                Ok(Either::Left((target_impl.clone(), trait_req.inner.clone())))
                            } else {
                                Err(UnificationError::Specific(diag_anyhow!(
                                    trait_req,
                                    "Found more than one impl of {} for {}",
                                    trait_req.display(self),
                                    var.display(self)
                                )))
                            }
                        } else {
                            // The target type has no impls registered at all.
                            Ok(Either::Right(trait_req.clone()))
                        }
                    })
                    .collect::<std::result::Result<Vec<_>, _>>()?
                    .into_iter()
                    .partition_map(|x| x);

                if unsatisfied.is_empty() {
                    self.trace_stack.push(TraceStackEntry::Message(
                        "Ensuring impl successful".to_string(),
                    ));
                    Ok(impls)
                } else {
                    error_producer!(TraitList::from_vec(unsatisfied.clone()))
                }
            }
            TypeVar::Unknown(_, _, _, _) => {
                panic!("running ensure_impls on an unknown type")
            }
            // Known types without an impl target implement nothing, so only an empty
            // requirement list can succeed.
            _ => {
                if traits.inner.is_empty() {
                    Ok(vec![])
                } else {
                    error_producer!(traits.clone())
                }
            }
        }
    }
3012
3013 pub fn unify_expression_generic_error(
3014 &mut self,
3015 expr: &Loc<Expression>,
3016 other: &impl HasType,
3017 ctx: &Context,
3018 ) -> Result<TypeVarID> {
3019 self.unify(&expr.inner, other, ctx)
3020 .into_default_diagnostic(expr.loc(), self)
3021 }
3022
3023 pub fn check_requirements(&mut self, is_final_check: bool, ctx: &Context) -> Result<()> {
3024 loop {
3026 let (retain, replacements_option): (Vec<_>, Vec<_>) = self
3030 .requirements
3031 .clone()
3032 .iter()
3033 .map(|req| match req.check(self, ctx)? {
3034 requirements::RequirementResult::NoChange => Ok((true, None)),
3035 requirements::RequirementResult::UnsatisfiedNow(diag) => {
3036 if is_final_check {
3037 Err(diag)
3038 } else {
3039 Ok((true, None))
3040 }
3041 }
3042 requirements::RequirementResult::Satisfied(replacement) => {
3043 self.trace_stack
3044 .push(TraceStackEntry::ResolvedRequirement(req.clone()));
3045 Ok((false, Some(replacement)))
3046 }
3047 })
3048 .collect::<Result<Vec<_>>>()?
3049 .into_iter()
3050 .unzip();
3051
3052 let replacements = replacements_option
3053 .into_iter()
3054 .flatten()
3055 .flatten()
3056 .collect::<Vec<_>>();
3057
3058 self.requirements = self
3060 .requirements
3061 .drain(0..)
3062 .zip(retain)
3063 .filter_map(|(req, keep)| if keep { Some(req) } else { None })
3064 .collect();
3065
3066 if replacements.is_empty() {
3067 break;
3068 }
3069
3070 for Replacement { from, to, context } in replacements {
3071 self.unify(&to, &from, ctx).into_diagnostic_or_default(
3072 from.loc(),
3073 context,
3074 self,
3075 )?;
3076 }
3077 }
3078
3079 Ok(())
3080 }
3081
3082 pub fn get_replacement(&self, var: &TypeVarID) -> TypeVarID {
3083 self.replacements.get(*var)
3084 }
3085
3086 pub fn do_and_restore<T, E>(
3087 &mut self,
3088 inner: impl Fn(&mut Self) -> std::result::Result<T, E>,
3089 ) -> std::result::Result<T, E> {
3090 self.checkpoint();
3091 let result = inner(self);
3092 self.restore();
3093 result
3094 }
3095
3096 fn checkpoint(&mut self) {
3099 self.trace_stack
3100 .push(TraceStackEntry::Enter("Creating checkpoint".to_string()));
3101 self.replacements.push();
3102
3103 self.checkpoints
3107 .push((self.requirements.clone(), self.constraints.clone()));
3108 }
3109
3110 fn restore(&mut self) {
3111 self.replacements.discard_top();
3112 self.trace_stack.push(TraceStackEntry::Exit);
3113
3114 let (requirements, constraints) = self
3115 .checkpoints
3116 .pop()
3117 .expect("Popped a checkpoint without any existing checkpoints.");
3118
3119 self.requirements = requirements;
3120 self.constraints = constraints;
3121 }
3122}
3123
3124impl TypeState {
3125 pub fn print_equations(&self) {
3126 for (lhs, rhs) in &self.equations {
3127 println!(
3128 "{} -> {}",
3129 format!("{lhs}").blue(),
3130 format!("{}", rhs.debug_resolve(self)).green()
3131 )
3132 }
3133
3134 println!("\nReplacments:");
3135
3136 for repl_stack in &self.replacements.all() {
3137 let replacements = { repl_stack.borrow().clone() };
3138 for (lhs, rhs) in replacements.iter().sorted() {
3139 println!(
3140 "{} -> {} ({} -> {})",
3141 format!("{}", lhs.inner).blue(),
3142 format!("{}", rhs.inner).green(),
3143 format!("{}", lhs.debug_resolve(self)).blue(),
3144 format!("{}", rhs.debug_resolve(self)).green(),
3145 )
3146 }
3147 println!("---")
3148 }
3149
3150 println!("\n Requirements:");
3151
3152 for requirement in &self.requirements {
3153 println!("{:?}", requirement)
3154 }
3155
3156 println!()
3157 }
3158}
3159
3160#[must_use]
3161pub struct UnificationBuilder {
3162 lhs: TypeVarID,
3163 rhs: TypeVarID,
3164}
3165impl UnificationBuilder {
3166 pub fn commit(
3167 self,
3168 state: &mut TypeState,
3169 ctx: &Context,
3170 ) -> std::result::Result<TypeVarID, UnificationError> {
3171 state.unify(&self.lhs, &self.rhs, ctx)
3172 }
3173}
3174
3175pub trait HasType: std::fmt::Debug {
3176 fn get_type(&self, state: &TypeState) -> TypeVarID {
3177 self.try_get_type(state)
3178 .unwrap_or(state.error_type.unwrap())
3179 }
3180
3181 fn try_get_type(&self, state: &TypeState) -> Option<TypeVarID> {
3182 let id = self.get_type_impl(state);
3183 id.map(|id| state.get_replacement(&id))
3184 }
3185
3186 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID>;
3187
3188 fn unify_with(&self, rhs: &dyn HasType, state: &TypeState) -> UnificationBuilder {
3189 UnificationBuilder {
3190 lhs: self.get_type(state),
3191 rhs: rhs.get_type(state),
3192 }
3193 }
3194}
3195
3196impl HasType for TypeVarID {
3197 fn get_type_impl(&self, _state: &TypeState) -> Option<TypeVarID> {
3198 Some(*self)
3199 }
3200}
3201impl HasType for Loc<TypeVarID> {
3202 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3203 self.inner.try_get_type(state)
3204 }
3205}
3206impl HasType for TypedExpression {
3207 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3208 state.maybe_type_of(self).cloned()
3209 }
3210}
3211impl HasType for Expression {
3212 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3213 state.maybe_type_of(&TypedExpression::Id(self.id)).cloned()
3214 }
3215}
3216impl HasType for Loc<Expression> {
3217 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3218 state
3219 .maybe_type_of(&TypedExpression::Id(self.inner.id))
3220 .cloned()
3221 }
3222}
3223impl HasType for Pattern {
3224 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3225 state.maybe_type_of(&TypedExpression::Id(self.id)).cloned()
3226 }
3227}
3228impl HasType for Loc<Pattern> {
3229 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3230 state
3231 .maybe_type_of(&TypedExpression::Id(self.inner.id))
3232 .cloned()
3233 }
3234}
3235impl HasType for NameID {
3236 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3237 state
3238 .maybe_type_of(&TypedExpression::Name(self.clone()))
3239 .cloned()
3240 }
3241}