1use std::cell::RefCell;
10use std::cmp::PartialEq;
11use std::collections::{BTreeMap, BTreeSet, HashMap};
12
13use colored::Colorize;
14use fixed_types::{t_int, t_uint};
15use hir::expression::CallKind;
16use hir::{
17 param_util, Binding, ConstGeneric, Parameter, PipelineRegMarkerExtra, TypeExpression, TypeSpec,
18 UnitHead, UnitKind, WalTrace, WhereClause,
19};
20use itertools::{Either, Itertools};
21use method_resolution::{FunctionLikeName, IntoImplTarget};
22use num::{BigInt, BigUint, Zero};
23use replacement::ReplacementStack;
24use serde::{Deserialize, Serialize};
25use spade_common::id_tracker::{ExprID, ImplID};
26use spade_common::num_ext::InfallibleToBigInt;
27use spade_diagnostics::diag_list::{DiagList, ResultExt};
28use spade_diagnostics::{diag_anyhow, diag_bail, Diagnostic};
29use spade_macros::trace_typechecker;
30use spade_types::meta_types::{unify_meta, MetaType};
31use trace_stack::TraceStack;
32use tracing::{info, trace};
33
34use spade_common::location_info::{Loc, WithLocation};
35use spade_common::name::{Identifier, NameID, Path};
36use spade_hir::param_util::{match_args_with_params, Argument};
37use spade_hir::symbol_table::{Patternable, PatternableKind, SymbolTable, TypeSymbol};
38use spade_hir::{self as hir, ConstGenericWithId, ImplTarget};
39use spade_hir::{
40 ArgumentList, Block, ExprKind, Expression, ItemList, Pattern, PatternArgument, Register,
41 Statement, TraitName, TraitSpec, TypeParam, Unit,
42};
43use spade_types::KnownType;
44
45use constraints::{
46 bits_to_store, ce_int, ce_var, ConstraintExpr, ConstraintSource, TypeConstraints,
47};
48use equation::{TemplateTypeVarID, TypeEquations, TypeVar, TypeVarID, TypedExpression};
49use error::{
50 error_pattern_type_mismatch, Result, UnificationError, UnificationErrorExt, UnificationTrace,
51};
52use requirements::{Replacement, Requirement};
53use trace_stack::{format_trace_stack, TraceStackEntry};
54use traits::{TraitList, TraitReq};
55
56use crate::error::TypeMismatch as Tm;
57use crate::requirements::ConstantInt;
58use crate::traits::{TraitImpl, TraitImplList};
59
60mod constraints;
61pub mod dump;
62pub mod equation;
63pub mod error;
64pub mod expression;
65pub mod fixed_types;
66pub mod method_resolution;
67pub mod mir_type_lowering;
68mod replacement;
69mod requirements;
70pub mod testutil;
71pub mod trace_stack;
72pub mod traits;
73
/// Read-only compiler tables the type checker consults while visiting units.
pub struct Context<'a> {
    /// Symbol table used to look up units, traits and names.
    pub symtab: &'a SymbolTable,
    /// The item list of the program being checked.
    pub items: &'a ItemList,
    /// Trait implementations; copied into `TypeState::trait_impls` when a unit is visited.
    pub trait_impls: &'a TraitImplList,
}
79impl<'a> std::fmt::Debug for Context<'a> {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(f, "{{context omitted}}")
82 }
83}
84
/// Pushes a formatted message onto `$self.trace_stack`.
///
/// Usage: `add_trace!(self, "unifying {} with {}", a, b)`. The arguments after
/// `$self` are forwarded verbatim to `format!`.
///
/// The entry pushed is a `TraceStackEntry::Message`, which is the entry type the
/// `TraceStack` in `TypeState::trace_stack` stores. `TraceStack` itself has no
/// `Message` variant.
#[allow(unused_macros)]
macro_rules! add_trace {
    ($self:expr, $($arg : tt) *) => {
        $self.trace_stack.push(TraceStackEntry::Message(format!($($arg)*)))
    }
}
92
/// Describes what a generic list is being created for; passed to
/// `TypeState::create_generic_list`.
#[derive(Debug)]
pub enum GenericListSource<'a> {
    /// A generic list that is not tied to a definition, impl block or expression.
    /// NOTE(review): confirm where anonymous lists are created.
    Anonymous,
    /// The type parameters of a unit definition, identified by the unit's name.
    Definition(&'a NameID),
    /// The type parameters of an impl block.
    ImplBlock {
        target: &'a ImplTarget,
        id: ImplID,
    },
    /// Type parameters instantiated at a specific expression, e.g. a call or a
    /// lambda definition.
    Expression(ExprID),
}
108
/// Owned, hashable key identifying a generic list in `TypeState::generic_lists`.
/// Its variants mirror those of `GenericListSource`.
#[derive(Clone, Hash, Eq, PartialEq, Debug, Serialize, Deserialize)]
pub enum GenericListToken {
    /// NOTE(review): presumably a counter that distinguishes anonymous lists; confirm.
    Anonymous(usize),
    Definition(NameID),
    ImplBlock(ImplTarget, ImplID),
    Expression(ExprID),
}
117
/// Explicit type arguments (`::<...>`) given at a call site, together with what is
/// needed to interpret them.
#[derive(Debug)]
pub struct TurbofishCtx<'a> {
    /// The turbofish argument list as written by the user.
    turbofish: &'a Loc<ArgumentList<TypeExpression>>,
    /// The generic list of the code containing the call (the caller's list).
    prev_generic_list: &'a GenericListToken,
    type_ctx: &'a Context<'a>,
}
124
/// Bookkeeping kept while the body of a pipeline unit is being type checked.
#[derive(Clone, Serialize, Deserialize)]
pub struct PipelineState {
    /// Depth of the stage currently being visited; starts out as the integer 0.
    current_stage_depth: TypeVarID,
    /// The depth declared in the pipeline head.
    total_depth: Loc<TypeVarID>,
    /// Location of the whole pipeline unit, used for depth mismatch errors.
    pipeline_loc: Loc<()>,
}
131
/// Mutable state of type inference: the type variables, the equations mapping
/// expressions to them, generic lists, constraints and requirements collected while
/// visiting units.
#[derive(Clone, Serialize, Deserialize)]
pub struct TypeState {
    /// Every type variable created in this state. A `TypeVarID`'s `inner` field is an
    /// index into this list (see `add_type_var`).
    type_vars: Vec<TypeVar>,
    /// Random key of this state. It is stamped into every `TypeVarID` created here;
    /// `create_child` gives each child a new key.
    key: u64,
    /// The key of this state plus the keys of all states it was derived from through
    /// `create_child`.
    keys: BTreeSet<u64>,

    /// Maps typed expressions (expression ids and names) to their type variables.
    equations: TypeEquations,
    /// Counter used when allocating ids for unknown type variables; starts at 0.
    /// NOTE(review): presumably a `RefCell` so ids can be allocated through `&self`;
    /// confirm against `new_typeid`.
    next_typeid: RefCell<u64>,
    /// Generic lists keyed by what created them, each mapping a generic's name to the
    /// type variable it is bound to.
    generic_lists: HashMap<GenericListToken, HashMap<NameID, TypeVarID>>,

    /// Type-level constraints added through `add_constraint`.
    constraints: TypeConstraints,

    /// Requirements checked later by `check_requirements`. Not serialized.
    #[serde(skip)]
    requirements: Vec<Requirement>,

    /// NOTE(review): stack of type variable replacements; semantics are defined in
    /// the `replacement` module.
    replacements: ReplacementStack,

    /// Present only while the body of a pipeline unit is being visited.
    pipeline_state: Option<PipelineState>,

    /// Trait impls available to the checker; overwritten from the `Context` at the
    /// start of every unit visit.
    pub trait_impls: TraitImplList,

    /// Trace of type checking steps, used for debugging. Not serialized.
    #[serde(skip)]
    pub trace_stack: TraceStack,

    /// Diagnostics from expressions that failed to type check. Not serialized.
    #[serde(skip)]
    pub diags: DiagList,
}
180
181impl TypeState {
182 pub fn fresh() -> Self {
185 let key = fastrand::u64(..);
186 Self {
187 type_vars: vec![],
188 key,
189 keys: [key].into_iter().collect(),
190 equations: HashMap::new(),
191 next_typeid: RefCell::new(0),
192 trace_stack: TraceStack::new(),
193 constraints: TypeConstraints::new(),
194 requirements: vec![],
195 replacements: ReplacementStack::new(),
196 generic_lists: HashMap::new(),
197 trait_impls: TraitImplList::new(),
198 pipeline_state: None,
199 diags: DiagList::new(),
200 }
201 }
202
203 pub fn create_child(&self) -> Self {
204 let mut result = self.clone();
205 result.key = fastrand::u64(..);
206 result.keys.insert(result.key);
207 result
208 }
209
210 pub fn add_type_var(&mut self, var: TypeVar) -> TypeVarID {
211 let idx = self.type_vars.len();
212 self.type_vars.push(var);
213 TypeVarID {
214 inner: idx,
215 type_state_key: self.key,
216 }
217 }
218
219 pub fn get_equations(&self) -> &TypeEquations {
220 &self.equations
221 }
222
223 pub fn get_constraints(&self) -> &TypeConstraints {
224 &self.constraints
225 }
226
227 pub fn get_generic_list<'a>(
229 &'a self,
230 generic_list_token: &'a GenericListToken,
231 ) -> Option<&'a HashMap<NameID, TypeVarID>> {
232 self.generic_lists.get(generic_list_token)
233 }
234
235 #[tracing::instrument(level = "trace", skip_all)]
236 fn hir_type_expr_to_var<'a>(
237 &'a mut self,
238 e: &Loc<hir::TypeExpression>,
239 generic_list_token: &GenericListToken,
240 ) -> Result<TypeVarID> {
241 let id = match &e.inner {
242 hir::TypeExpression::Integer(i) => self.add_type_var(TypeVar::Known(
243 e.loc(),
244 KnownType::Integer(i.clone()),
245 vec![],
246 )),
247 hir::TypeExpression::TypeSpec(spec) => {
248 self.type_var_from_hir(e.loc(), &spec.clone(), generic_list_token)?
249 }
250 hir::TypeExpression::ConstGeneric(g) => {
251 let constraint = self.visit_const_generic(g, generic_list_token)?;
252
253 let tvar = self.new_generic_tlnumber(e.loc());
254 self.add_constraint(
255 tvar.clone(),
256 constraint,
257 g.loc(),
258 &tvar,
259 ConstraintSource::Where,
260 );
261
262 tvar
263 }
264 };
265 Ok(id)
266 }
267
268 #[tracing::instrument(level = "trace", skip_all, fields(%hir_type))]
269 pub fn type_var_from_hir<'a>(
270 &'a mut self,
271 loc: Loc<()>,
272 hir_type: &crate::hir::TypeSpec,
273 generic_list_token: &GenericListToken,
274 ) -> Result<TypeVarID> {
275 let generic_list = self.get_generic_list(generic_list_token);
276 let var = match &hir_type {
277 hir::TypeSpec::Declared(base, params) => {
278 let params = params
279 .iter()
280 .map(|e| self.hir_type_expr_to_var(e, generic_list_token))
281 .collect::<Result<Vec<_>>>()?;
282
283 self.add_type_var(TypeVar::Known(
284 loc,
285 KnownType::Named(base.inner.clone()),
286 params,
287 ))
288 }
289 hir::TypeSpec::Generic(name) => match generic_list
290 .ok_or_else(|| diag_anyhow!(loc, "Found no generic list for {name}"))?
291 .get(&name.inner)
292 {
293 Some(t) => t.clone(),
294 None => {
295 for list_source in self.generic_lists.keys() {
296 info!("Generic lists exist for {list_source:?}");
297 }
298 info!("Current source is {generic_list_token:?}");
299 panic!("No entry in generic list for {name:?}");
300 }
301 },
302 hir::TypeSpec::Tuple(inner) => {
303 let inner = inner
304 .iter()
305 .map(|t| self.type_var_from_hir(loc, t, generic_list_token))
306 .collect::<Result<_>>()?;
307 self.add_type_var(TypeVar::tuple(loc, inner))
308 }
309 hir::TypeSpec::Array { inner, size } => {
310 let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
311 let size = self.hir_type_expr_to_var(size, generic_list_token)?;
312
313 self.add_type_var(TypeVar::array(loc, inner, size))
314 }
315 hir::TypeSpec::Wire(inner) => {
316 let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
317 self.add_type_var(TypeVar::wire(loc, inner))
318 }
319 hir::TypeSpec::Inverted(inner) => {
320 let inner = self.type_var_from_hir(loc, inner, generic_list_token)?;
321 self.add_type_var(TypeVar::inverted(loc, inner))
322 }
323 hir::TypeSpec::Wildcard(_) => self.new_generic_any(),
324 hir::TypeSpec::TraitSelf(_) => {
325 diag_bail!(
326 loc,
327 "Trying to convert TraitSelf to type inference type var"
328 )
329 }
330 };
331
332 Ok(var)
333 }
334
335 pub fn type_of(&self, expr: &TypedExpression) -> TypeVarID {
338 if let Some(t) = self.equations.get(expr) {
339 *t
340 } else {
341 panic!("Tried looking up the type of {expr:?} but it was not found")
342 }
343 }
344
345 pub fn maybe_type_of(&self, expr: &TypedExpression) -> Option<&TypeVarID> {
346 self.equations.get(expr)
347 }
348
349 pub fn new_generic_int(&mut self, loc: Loc<()>, symtab: &SymbolTable) -> TypeVar {
350 TypeVar::Known(loc, t_int(symtab), vec![self.new_generic_tluint(loc)])
351 }
352
353 pub fn new_concrete_int(&mut self, size: BigUint, loc: Loc<()>) -> TypeVarID {
354 TypeVar::Known(loc, KnownType::Integer(size.to_bigint()), vec![]).insert(self)
355 }
356
357 pub fn new_split_generic_int(
360 &mut self,
361 loc: Loc<()>,
362 symtab: &SymbolTable,
363 ) -> (TypeVarID, TypeVarID) {
364 let size = self.new_generic_tlint(loc);
365 let full = self.add_type_var(TypeVar::Known(loc, t_int(symtab), vec![size.clone()]));
366 (full, size)
367 }
368
369 pub fn new_split_generic_uint(
370 &mut self,
371 loc: Loc<()>,
372 symtab: &SymbolTable,
373 ) -> (TypeVarID, TypeVarID) {
374 let size = self.new_generic_tluint(loc);
375 let full = self.add_type_var(TypeVar::Known(loc, t_uint(symtab), vec![size.clone()]));
376 (full, size)
377 }
378
379 pub fn new_generic_with_meta(&mut self, loc: Loc<()>, meta: MetaType) -> TypeVarID {
380 let id = self.new_typeid();
381 self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), meta))
382 }
383
384 pub fn new_generic_type(&mut self, loc: Loc<()>) -> TypeVarID {
385 let id = self.new_typeid();
386 self.add_type_var(TypeVar::Unknown(
387 loc,
388 id,
389 TraitList::empty(),
390 MetaType::Type,
391 ))
392 }
393
394 pub fn new_generic_any(&mut self) -> TypeVarID {
395 let id = self.new_typeid();
396 self.add_type_var(TypeVar::Unknown(
399 ().nowhere(),
400 id,
401 TraitList::empty(),
402 MetaType::Any,
403 ))
404 }
405
406 pub fn new_generic_tlbool(&mut self, loc: Loc<()>) -> TypeVarID {
407 let id = self.new_typeid();
408 self.add_type_var(TypeVar::Unknown(
409 loc,
410 id,
411 TraitList::empty(),
412 MetaType::Bool,
413 ))
414 }
415
416 pub fn new_generic_tluint(&mut self, loc: Loc<()>) -> TypeVarID {
417 let id = self.new_typeid();
418 self.add_type_var(TypeVar::Unknown(
419 loc,
420 id,
421 TraitList::empty(),
422 MetaType::Uint,
423 ))
424 }
425
426 pub fn new_generic_tlint(&mut self, loc: Loc<()>) -> TypeVarID {
427 let id = self.new_typeid();
428 self.add_type_var(TypeVar::Unknown(loc, id, TraitList::empty(), MetaType::Int))
429 }
430
431 pub fn new_generic_tlnumber(&mut self, loc: Loc<()>) -> TypeVarID {
432 let id = self.new_typeid();
433 self.add_type_var(TypeVar::Unknown(
434 loc,
435 id,
436 TraitList::empty(),
437 MetaType::Number,
438 ))
439 }
440
441 pub fn new_generic_number(&mut self, loc: Loc<()>, ctx: &Context) -> (TypeVarID, TypeVarID) {
442 let number = ctx
443 .symtab
444 .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
445 .expect("Did not find number in symtab")
446 .0;
447 let id = self.new_typeid();
448 let size = self.new_generic_tluint(loc);
449 let t = TraitReq {
450 name: TraitName::Named(number.nowhere()),
451 type_params: vec![size.clone()],
452 }
453 .nowhere();
454 (
455 self.add_type_var(TypeVar::Unknown(
456 loc,
457 id,
458 TraitList::from_vec(vec![t]),
459 MetaType::Type,
460 )),
461 size,
462 )
463 }
464
465 pub fn new_generic_with_traits(&mut self, loc: Loc<()>, traits: TraitList) -> TypeVarID {
466 let id = self.new_typeid();
467 self.add_type_var(TypeVar::Unknown(loc, id, traits, MetaType::Type))
468 }
469
470 pub fn get_pipeline_state<T>(&self, access_loc: &Loc<T>) -> Result<&PipelineState> {
474 self.pipeline_state
475 .as_ref()
476 .ok_or_else(|| diag_anyhow!(access_loc, "Expected to have a pipeline state"))
477 }
478
479 pub fn visit_unit_with_preprocessing(
481 &mut self,
482 entity: &Loc<Unit>,
483 pp: impl Fn(&mut TypeState, &Loc<Unit>, &GenericListToken, &Context) -> Result<()>,
484 ctx: &Context,
485 ) -> Result<()> {
486 self.trait_impls = ctx.trait_impls.clone();
487
488 let generic_list = self.create_generic_list(
489 GenericListSource::Definition(&entity.name.name_id().inner),
490 &entity.head.unit_type_params,
491 &entity.head.scope_type_params,
492 None,
493 &entity.head.where_clauses,
496 )?;
497
498 pp(self, entity, &generic_list, ctx)?;
499
500 for (name, t) in &entity.inputs {
502 let tvar = self.type_var_from_hir(t.loc(), t, &generic_list)?;
503 self.add_equation(TypedExpression::Name(name.inner.clone()), tvar)
504 }
505
506 if let UnitKind::Pipeline {
507 depth,
508 depth_typeexpr_id,
509 } = &entity.head.unit_kind.inner
510 {
511 let depth_var = self.hir_type_expr_to_var(depth, &generic_list)?;
512 self.add_equation(TypedExpression::Id(*depth_typeexpr_id), depth_var.clone());
513 self.pipeline_state = Some(PipelineState {
514 current_stage_depth: self.add_type_var(TypeVar::Known(
515 entity.head.unit_kind.loc(),
516 KnownType::Integer(BigInt::zero()),
517 vec![],
518 )),
519 pipeline_loc: entity.loc(),
520 total_depth: depth_var.clone().at_loc(depth),
521 });
522 self.add_requirement(Requirement::PositivePipelineDepth {
523 depth: depth_var.at_loc(depth),
524 });
525 TypedExpression::Name(entity.inputs[0].0.clone().inner)
526 .unify_with(&self.t_clock(entity.head.unit_kind.loc(), ctx.symtab), self)
527 .commit(self, ctx)
528 .into_diagnostic(
529 entity.inputs[0].1.loc(),
530 |diag,
531 Tm {
532 g: got,
533 e: _expected,
534 }| {
535 diag.message(format!(
536 "First pipeline argument must be a clock. Got {}",
537 got.display(self)
538 ))
539 .primary_label("expected clock")
540 },
541 self,
542 )?;
543 self.check_requirements(false, ctx)?;
546 }
547
548 self.visit_expression(&entity.body, ctx, &generic_list);
549
550 if let Some(output_type) = &entity.head.output_type {
552 let tvar = self.type_var_from_hir(output_type.loc(), output_type, &generic_list)?;
553
554 self.trace_stack.push(TraceStackEntry::Message(format!(
555 "Unifying with output type {}",
556 tvar.debug_resolve(self)
557 )));
558 self.unify(&TypedExpression::Id(entity.body.inner.id), &tvar, ctx)
559 .into_diagnostic_no_expected_source(
560 &entity.body,
561 |diag,
562 Tm {
563 g: got,
564 e: expected,
565 }| {
566 let expected = expected.display(self);
567 let got = got.display(self);
568 diag.message(format!(
571 "Output type mismatch. Expected {expected}, got {got}"
572 ))
573 .primary_label(format!("Found type {got}"))
574 .secondary_label(output_type, format!("{expected} type specified here"))
575 },
576 self,
577 )?;
578 } else {
579 TypedExpression::Id(entity.body.inner.id)
581 .unify_with(&self.add_type_var(TypeVar::unit(entity.head.name.loc())), self)
582 .commit(self, ctx)
583 .into_diagnostic_no_expected_source(entity.body.loc(), |diag, Tm{g: got, e: _expected}| {
584 diag.message("Output type mismatch")
585 .primary_label(format!("Found type {got}", got = got.display(self)))
586 .note(format!(
587 "The {} does not specify a return type.\nAdd a return type, or remove the return value.",
588 entity.head.unit_kind.name()
589 ))
590 }, self)?;
591 }
592
593 if let Some(PipelineState {
594 current_stage_depth,
595 pipeline_loc,
596 total_depth,
597 }) = self.pipeline_state.clone()
598 {
599 self.unify(&total_depth.inner, ¤t_stage_depth, ctx)
600 .into_diagnostic_no_expected_source(
601 pipeline_loc,
602 |diag, tm| {
603 let (e, g) = tm.display_e_g(self);
604 diag.message(format!("Pipeline depth mismatch. Expected {g} got {e}"))
605 .primary_label(format!("Found {e} stages in this pipeline"))
606 },
607 self,
608 )?;
609 }
610
611 self.check_requirements(true, ctx)?;
612
613 self.pipeline_state = None;
618
619 Ok(())
620 }
621
622 #[trace_typechecker]
623 #[tracing::instrument(level = "trace", skip_all, fields(%entity.name))]
624 pub fn visit_unit(&mut self, entity: &Loc<Unit>, ctx: &Context) -> Result<()> {
625 self.visit_unit_with_preprocessing(entity, |_, _, _, _| Ok(()), ctx)
626 }
627
628 #[trace_typechecker]
629 #[tracing::instrument(level = "trace", skip_all)]
630 fn visit_argument_list(
631 &mut self,
632 args: &Loc<ArgumentList<Expression>>,
633 ctx: &Context,
634 generic_list: &GenericListToken,
635 ) -> Result<()> {
636 for expr in args.expressions() {
637 self.visit_expression(expr, ctx, generic_list);
638 }
639 Ok(())
640 }
641
642 #[trace_typechecker]
643 fn type_check_argument_list(
644 &mut self,
645 args: &[Argument<Expression, TypeSpec>],
646 ctx: &Context,
647 generic_list: &GenericListToken,
648 ) -> Result<()> {
649 for Argument {
650 target,
651 target_type,
652 value,
653 kind,
654 } in args.iter()
655 {
656 let target_type = self.type_var_from_hir(value.loc(), target_type, generic_list)?;
657
658 let loc = match kind {
659 hir::param_util::ArgumentKind::Positional => value.loc(),
660 hir::param_util::ArgumentKind::Named => value.loc(),
661 hir::param_util::ArgumentKind::ShortNamed => target.loc(),
662 };
663
664 self.unify(&value.inner, &target_type, ctx)
665 .into_diagnostic(
666 loc,
667 |d, tm| {
668 let (expected, got) = tm.display_e_g(self);
669 d.message(format!(
670 "Argument type mismatch. Expected {expected} got {got}"
671 ))
672 .primary_label(format!("expected {expected}"))
673 },
674 self,
675 )?;
676 }
677
678 Ok(())
679 }
680
    /// Type checks a single expression whose type variable `new_type` has already been
    /// registered for `expression.inner.id` by the caller (see `visit_expression`).
    ///
    /// Most expression kinds are dispatched to a dedicated `visit_*` method. These kinds
    /// are handled inline:
    /// - `Error` expressions;
    /// - calls;
    /// - stage ready/valid;
    /// - type-level `if`;
    /// - lambda definitions.
    ///
    /// # Errors
    /// Returns the first diagnostic produced while checking this expression.
    /// Sub-expressions visited through `visit_expression` record their own errors in
    /// `self.diags` instead of returning them.
    #[trace_typechecker]
    pub fn visit_expression_result(
        &mut self,
        expression: &Loc<Expression>,
        ctx: &Context,
        generic_list: &GenericListToken,
        new_type: TypeVarID,
    ) -> Result<()> {
        match &expression.inner.kind {
            // `Error` expressions are given the error type.
            ExprKind::Error => {
                new_type
                    .unify_with(&self.t_err(expression.loc()), self)
                    .commit(self, ctx)
                    .unwrap();
            }
            ExprKind::Identifier(_) => self.visit_identifier(expression, ctx)?,
            ExprKind::TypeLevelInteger(_) => {
                self.visit_type_level_integer(expression, generic_list, ctx)?
            }
            ExprKind::IntLiteral(_, _) => self.visit_int_literal(expression, ctx)?,
            ExprKind::BoolLiteral(_) => self.visit_bool_literal(expression, ctx)?,
            ExprKind::BitLiteral(_) => self.visit_bit_literal(expression, ctx)?,
            ExprKind::TupleLiteral(_) => self.visit_tuple_literal(expression, ctx, generic_list)?,
            ExprKind::TupleIndex(_, _) => self.visit_tuple_index(expression, ctx, generic_list)?,
            ExprKind::ArrayLiteral(_) => self.visit_array_literal(expression, ctx, generic_list)?,
            ExprKind::ArrayShorthandLiteral(_, _) => {
                self.visit_array_shorthand_literal(expression, ctx, generic_list)?
            }
            ExprKind::CreatePorts => self.visit_create_ports(expression, ctx, generic_list)?,
            ExprKind::FieldAccess(_, _) => {
                self.visit_field_access(expression, ctx, generic_list)?
            }
            ExprKind::MethodCall { .. } => self.visit_method_call(expression, ctx, generic_list)?,
            ExprKind::Index(_, _) => self.visit_index(expression, ctx, generic_list)?,
            ExprKind::RangeIndex { .. } => self.visit_range_index(expression, ctx, generic_list)?,
            ExprKind::Block(_) => self.visit_block_expr(expression, ctx, generic_list)?,
            ExprKind::If(_, _, _) => self.visit_if(expression, ctx, generic_list)?,
            ExprKind::Match(_, _) => self.visit_match(expression, ctx, generic_list)?,
            ExprKind::BinaryOperator(_, _, _) => {
                self.visit_binary_operator(expression, ctx, generic_list)?
            }
            ExprKind::UnaryOperator(_, _) => {
                self.visit_unary_operator(expression, ctx, generic_list)?
            }
            // Free-standing calls: look up the callee's head and check the call against it.
            ExprKind::Call {
                kind,
                callee,
                args,
                turbofish,
            } => {
                let head = ctx.symtab.unit_by_id(&callee.inner);

                self.handle_function_like(
                    expression.map_ref(|e| e.id),
                    &expression.get_type(self),
                    &FunctionLikeName::Free(callee.inner.clone()),
                    &head,
                    kind,
                    args,
                    ctx,
                    // visit_args: the argument expressions are visited by the callee check
                    true,
                    // is_method
                    false,
                    turbofish.as_ref().map(|turbofish| TurbofishCtx {
                        turbofish,
                        prev_generic_list: generic_list,
                        type_ctx: ctx,
                    }),
                    generic_list,
                )?;
            }
            ExprKind::PipelineRef { .. } => {
                self.visit_pipeline_ref(expression, generic_list, ctx)?;
            }
            // Stage ready/valid signals are booleans.
            ExprKind::StageReady | ExprKind::StageValid => {
                expression
                    .unify_with(&self.t_bool(expression.loc(), ctx.symtab), self)
                    .commit(self, ctx)
                    .into_default_diagnostic(expression, self)?;
            }

            // Type-level `if`: the condition must be a `#bool`, and both branches must
            // have the same type as the whole expression.
            ExprKind::TypeLevelIf(cond, on_true, on_false) => {
                let cond_var = self.visit_const_generic_with_id(
                    cond,
                    generic_list,
                    ConstraintSource::TypeLevelIf,
                    ctx,
                )?;
                let t_bool = self.new_generic_tlbool(cond.loc());
                self.unify(&cond_var, &t_bool, ctx).into_diagnostic(
                    cond,
                    |diag, tm| {
                        let (_e, g) = tm.display_e_g(self);
                        diag.message(format!("gen if conditions must be #bool, got {g}"))
                    },
                    self,
                )?;

                self.visit_expression(on_true, ctx, generic_list);
                self.visit_expression(on_false, ctx, generic_list);

                self.unify_expression_generic_error(expression, on_true.as_ref(), ctx)?;
                self.unify_expression_generic_error(expression, on_false.as_ref(), ctx)?;
            }
            // Lambda definitions. The lambda's type is the named `lambda_type`, whose
            // parameters are, in order:
            // - the argument types;
            // - the body type;
            // - the types bound to each captured generic in the enclosing generic list.
            ExprKind::LambdaDef {
                arguments,
                body,
                lambda_type,
                lambda_type_params,
                captured_generic_params,
                lambda_unit: _,
            } => {
                for arg in arguments {
                    self.visit_pattern(arg, ctx, generic_list)?;
                }

                self.visit_expression(body, ctx, generic_list);

                let lambda_params = arguments
                    .iter()
                    .map(|arg| arg.get_type(self))
                    .chain(vec![body.get_type(self)])
                    .chain(
                        captured_generic_params
                            .iter()
                            .map(|cap| {
                                let t = self
                                    .get_generic_list(generic_list)
                                    .ok_or_else(|| {
                                        diag_anyhow!(
                                            expression,
                                            "Found a captured generic but no generic list"
                                        )
                                    })?
                                    .get(&cap.name_in_body)
                                    .ok_or_else(|| {
                                        diag_anyhow!(
                                            &cap.name_in_body,
                                            "Did not find an entry for {} in lambda generic list",
                                            cap.name_in_body
                                        )
                                    });
                                Ok(t?.clone())
                            })
                            .collect::<Result<Vec<_>>>()?
                            .into_iter(),
                    )
                    .collect::<Vec<_>>();

                let self_type = TypeVar::Known(
                    expression.loc(),
                    KnownType::Named(lambda_type.clone()),
                    lambda_params.clone(),
                );

                // A generic list for the lambda's own type parameters, keyed by this
                // expression. Each of its entries is unified, in order, with the
                // corresponding entry of `lambda_params`.
                let unit_generic_list = self.create_generic_list(
                    GenericListSource::Expression(expression.id),
                    &lambda_type_params,
                    &[],
                    None,
                    &[],
                )?;

                for (p, tp) in lambda_params.iter().zip(lambda_type_params) {
                    let gl = self.get_generic_list(&unit_generic_list).unwrap();
                    p.unify_with(
                        gl.get(&tp.name_id).ok_or_else(|| {
                            diag_anyhow!(
                                expression,
                                "Lambda unit list did not contain {}",
                                tp.name_id
                            )
                        })?,
                        self,
                    )
                    .commit(self, ctx)
                    .into_default_diagnostic(expression, self)?;
                }
                expression
                    .unify_with(&self.add_type_var(self_type), self)
                    .commit(self, ctx)
                    .into_default_diagnostic(expression, self)?;
            }
            ExprKind::StaticUnreachable(_) => {}
            ExprKind::Null => {}
        }
        Ok(())
    }
870
871 #[tracing::instrument(level = "trace", skip_all)]
872 pub fn visit_expression(
873 &mut self,
874 expression: &Loc<Expression>,
875 ctx: &Context,
876 generic_list: &GenericListToken,
877 ) {
878 let new_type = self.new_generic_type(expression.loc());
879 self.add_equation(TypedExpression::Id(expression.inner.id), new_type);
880
881 match self.visit_expression_result(expression, ctx, generic_list, new_type) {
882 Ok(_) => {}
883 Err(e) => {
884 new_type
885 .unify_with(&self.t_err(expression.loc()), self)
886 .commit(self, ctx)
887 .unwrap();
888
889 self.diags.errors.push(e);
890 }
891 }
892 }
893
894 #[tracing::instrument(level = "trace", skip_all, fields(%name))]
896 #[trace_typechecker]
897 fn handle_function_like(
898 &mut self,
899 expression_id: Loc<ExprID>,
900 expression_type: &TypeVarID,
901 name: &FunctionLikeName,
902 head: &Loc<UnitHead>,
903 call_kind: &CallKind,
904 args: &Loc<ArgumentList<Expression>>,
905 ctx: &Context,
906 visit_args: bool,
910 is_method: bool,
913 turbofish: Option<TurbofishCtx>,
914 generic_list: &GenericListToken,
915 ) -> Result<()> {
916 let unit_generic_list = self.create_generic_list(
918 GenericListSource::Expression(expression_id.inner),
919 &head.unit_type_params,
920 &head.scope_type_params,
921 turbofish,
922 &head.where_clauses,
923 )?;
924
925 match (&head.unit_kind.inner, call_kind) {
926 (
927 UnitKind::Pipeline {
928 depth: udepth,
929 depth_typeexpr_id: _,
930 },
931 CallKind::Pipeline {
932 inst_loc: _,
933 depth: cdepth,
934 depth_typeexpr_id: cdepth_typeexpr_id,
935 },
936 ) => {
937 let definition_depth = self.hir_type_expr_to_var(udepth, &unit_generic_list)?;
938 let call_depth = self.hir_type_expr_to_var(cdepth, generic_list)?;
939
940 self.add_equation(TypedExpression::Id(*cdepth_typeexpr_id), call_depth.clone());
944
945 self.unify(&call_depth, &definition_depth, ctx)
946 .into_diagnostic_no_expected_source(
947 cdepth,
948 |diag, tm| {
949 let (e, g) = tm.display_e_g(self);
950 diag.message("Pipeline depth mismatch")
951 .primary_label(format!("Expected depth {e}, got {g}"))
952 .secondary_label(udepth, format!("{name} has depth {e}"))
953 },
954 self,
955 )?;
956 }
957 _ => {}
958 }
959
960 if visit_args {
961 self.visit_argument_list(args, ctx, &generic_list)?;
962 }
963
964 let type_params = &head.get_type_params();
965
966 macro_rules! handle_special_functions {
968 ($([$($path:expr),*] => $handler:expr),*) => {
969 $(
970 let path = Path(vec![$(Identifier($path.to_string()).nowhere()),*]).nowhere();
971 if ctx.symtab
972 .try_lookup_final_id(&path, &[])
973 .map(|n| &FunctionLikeName::Free(n) == name)
974 .unwrap_or(false)
975 {
976 $handler
977 };
978 )*
979 }
980 }
981
982 macro_rules! generic_arg {
985 ($idx:expr) => {
986 self.get_generic_list(&unit_generic_list)
987 .ok_or_else(|| diag_anyhow!(expression_id, "Found no generic list for call"))?
988 [&type_params[$idx].name_id()]
989 .clone()
990 };
991 }
992
993 let matched_args =
994 match_args_with_params(args, &head.inputs.inner, is_method).map_err(|e| {
995 let diag: Diagnostic = e.into();
996 diag.secondary_label(
997 head,
998 format!("{kind} defined here", kind = head.unit_kind.name()),
999 )
1000 })?;
1001
1002 handle_special_functions! {
1003 ["std", "conv", "concat"] => {
1004 self.handle_concat(
1005 expression_id,
1006 generic_arg!(0),
1007 generic_arg!(1),
1008 generic_arg!(2),
1009 &matched_args,
1010 ctx
1011 )?
1012 },
1013 ["std", "conv", "trunc"] => {
1014 self.handle_trunc(
1015 expression_id,
1016 generic_arg!(0),
1017 generic_arg!(1),
1018 &matched_args,
1019 ctx
1020 )?
1021 },
1022 ["std", "ops", "comb_div"] => {
1023 self.handle_comb_mod_or_div(
1024 generic_arg!(0),
1025 &matched_args,
1026 ctx
1027 )?
1028 },
1029 ["std", "ops", "comb_mod"] => {
1030 self.handle_comb_mod_or_div(
1031 generic_arg!(0),
1032 &matched_args,
1033 ctx
1034 )?
1035 },
1036 ["std", "mem", "clocked_memory"] => {
1037 let num_elements = generic_arg!(0);
1038 let addr_size = generic_arg!(2);
1039
1040 self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
1041 },
1042 ["std", "mem", "clocked_memory_init"] => {
1045 let num_elements = generic_arg!(0);
1046 let addr_size = generic_arg!(2);
1047
1048 self.handle_clocked_memory(num_elements, addr_size, &matched_args, ctx)?
1049 },
1050 ["std", "mem", "read_memory"] => {
1051 let addr_size = generic_arg!(0);
1052 let num_elements = generic_arg!(2);
1053
1054 self.handle_read_memory(num_elements, addr_size, &matched_args, ctx)?
1055 }
1056 };
1057
1058 self.type_check_argument_list(&matched_args, ctx, &unit_generic_list)?;
1060
1061 let return_type = head
1062 .output_type
1063 .as_ref()
1064 .map(|o| self.type_var_from_hir(expression_id.loc(), o, &unit_generic_list))
1065 .transpose()?
1066 .unwrap_or_else(|| {
1067 self.add_type_var(TypeVar::Known(
1068 expression_id.loc(),
1069 KnownType::Tuple,
1070 vec![],
1071 ))
1072 });
1073
1074 self.unify(expression_type, &return_type, ctx)
1075 .into_default_diagnostic(expression_id.loc(), self)?;
1076
1077 Ok(())
1078 }
1079
1080 pub fn handle_concat(
1081 &mut self,
1082 expression_id: Loc<ExprID>,
1083 source_lhs_ty: TypeVarID,
1084 source_rhs_ty: TypeVarID,
1085 source_result_ty: TypeVarID,
1086 args: &[Argument<Expression, TypeSpec>],
1087 ctx: &Context,
1088 ) -> Result<()> {
1089 let (lhs_type, lhs_size) = self.new_generic_number(expression_id.loc(), ctx);
1090 let (rhs_type, rhs_size) = self.new_generic_number(expression_id.loc(), ctx);
1091 let (result_type, result_size) = self.new_generic_number(expression_id.loc(), ctx);
1092 self.unify(&source_lhs_ty, &lhs_type, ctx)
1093 .into_default_diagnostic(args[0].value.loc(), self)?;
1094 self.unify(&source_rhs_ty, &rhs_type, ctx)
1095 .into_default_diagnostic(args[1].value.loc(), self)?;
1096 self.unify(&source_result_ty, &result_type, ctx)
1097 .into_default_diagnostic(expression_id.loc(), self)?;
1098
1099 self.add_constraint(
1101 result_size.clone(),
1102 ce_var(&lhs_size) + ce_var(&rhs_size),
1103 expression_id.loc(),
1104 &result_size,
1105 ConstraintSource::Concatenation,
1106 );
1107 self.add_constraint(
1108 lhs_size.clone(),
1109 ce_var(&result_size) + -ce_var(&rhs_size),
1110 args[0].value.loc(),
1111 &lhs_size,
1112 ConstraintSource::Concatenation,
1113 );
1114 self.add_constraint(
1115 rhs_size.clone(),
1116 ce_var(&result_size) + -ce_var(&lhs_size),
1117 args[1].value.loc(),
1118 &rhs_size,
1119 ConstraintSource::Concatenation,
1120 );
1121
1122 self.add_requirement(Requirement::SharedBase(vec![
1123 lhs_type.at_loc(args[0].value),
1124 rhs_type.at_loc(args[1].value),
1125 result_type.at_loc(&expression_id.loc()),
1126 ]));
1127 Ok(())
1128 }
1129
1130 pub fn handle_trunc(
1131 &mut self,
1132 expression_id: Loc<ExprID>,
1133 source_in_ty: TypeVarID,
1134 source_result_ty: TypeVarID,
1135 args: &[Argument<Expression, TypeSpec>],
1136 ctx: &Context,
1137 ) -> Result<()> {
1138 let (in_ty, _) = self.new_generic_number(expression_id.loc(), ctx);
1139 let (result_type, _) = self.new_generic_number(expression_id.loc(), ctx);
1140 self.unify(&source_in_ty, &in_ty, ctx)
1141 .into_default_diagnostic(args[0].value.loc(), self)?;
1142 self.unify(&source_result_ty, &result_type, ctx)
1143 .into_default_diagnostic(expression_id.loc(), self)?;
1144
1145 self.add_requirement(Requirement::SharedBase(vec![
1146 in_ty.at_loc(args[0].value),
1147 result_type.at_loc(&expression_id.loc()),
1148 ]));
1149 Ok(())
1150 }
1151
1152 pub fn handle_comb_mod_or_div(
1153 &mut self,
1154 n_ty: TypeVarID,
1155 args: &[Argument<Expression, TypeSpec>],
1156 ctx: &Context,
1157 ) -> Result<()> {
1158 let (num, _) = self.new_generic_number(args[0].value.loc(), ctx);
1159 self.unify(&n_ty, &num, ctx)
1160 .into_default_diagnostic(args[0].value.loc(), self)?;
1161 Ok(())
1162 }
1163
1164 pub fn handle_clocked_memory(
1165 &mut self,
1166 num_elements: TypeVarID,
1167 addr_size_arg: TypeVarID,
1168 args: &[Argument<Expression, TypeSpec>],
1169 ctx: &Context,
1170 ) -> Result<()> {
1171 let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
1174 let arg1_loc = args[1].value.loc();
1175 let tup = TypeVar::tuple(
1176 args[1].value.loc(),
1177 vec![
1178 self.new_generic_type(arg1_loc),
1179 addr_type,
1180 self.new_generic_type(arg1_loc),
1181 ],
1182 );
1183 let port_type = TypeVar::array(
1184 arg1_loc,
1185 self.add_type_var(tup),
1186 self.new_generic_tluint(arg1_loc),
1187 )
1188 .insert(self);
1189
1190 self.add_constraint(
1191 addr_size.clone(),
1192 bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
1193 args[1].value.loc(),
1194 &port_type,
1195 ConstraintSource::MemoryIndexing,
1196 );
1197
1198 self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1200 self.unify_expression_generic_error(args[1].value, &port_type, ctx)?;
1201
1202 Ok(())
1203 }
1204
1205 pub fn handle_read_memory(
1206 &mut self,
1207 num_elements: TypeVarID,
1208 addr_size_arg: TypeVarID,
1209 args: &[Argument<Expression, TypeSpec>],
1210 ctx: &Context,
1211 ) -> Result<()> {
1212 let (addr_type, addr_size) = self.new_split_generic_uint(args[1].value.loc(), ctx.symtab);
1213
1214 self.add_constraint(
1215 addr_size.clone(),
1216 bits_to_store(ce_var(&num_elements) - ce_int(1.to_bigint())),
1217 args[1].value.loc(),
1218 &addr_type,
1219 ConstraintSource::MemoryIndexing,
1220 );
1221
1222 self.unify(&addr_size, &addr_size_arg, ctx).unwrap();
1224
1225 Ok(())
1226 }
1227
1228 #[tracing::instrument(level = "trace", skip(self, turbofish, where_clauses))]
1229 pub fn create_generic_list(
1230 &mut self,
1231 source: GenericListSource,
1232 type_params: &[Loc<TypeParam>],
1233 scope_type_params: &[Loc<TypeParam>],
1234 turbofish: Option<TurbofishCtx>,
1235 where_clauses: &[Loc<WhereClause>],
1236 ) -> Result<GenericListToken> {
1237 let turbofish_params = if let Some(turbofish) = turbofish.as_ref() {
1238 if type_params.is_empty() {
1239 return Err(Diagnostic::error(
1240 turbofish.turbofish,
1241 "Turbofish on non-generic function",
1242 )
1243 .primary_label("Turbofish on non-generic function"));
1244 }
1245
1246 let matched_params =
1247 param_util::match_args_with_params(turbofish.turbofish, &type_params, false)?;
1248
1249 matched_params
1253 .iter()
1254 .map(|matched_param| {
1255 let i = type_params
1256 .iter()
1257 .enumerate()
1258 .find_map(|(i, param)| match ¶m.inner {
1259 TypeParam {
1260 ident,
1261 name_id: _,
1262 trait_bounds: _,
1263 meta: _,
1264 } => {
1265 if ident == matched_param.target {
1266 Some(i)
1267 } else {
1268 None
1269 }
1270 }
1271 })
1272 .unwrap();
1273 (i, matched_param)
1274 })
1275 .sorted_by_key(|(i, _)| *i)
1276 .map(|(_, mp)| Some(mp.value))
1277 .collect::<Vec<_>>()
1278 } else {
1279 type_params.iter().map(|_| None).collect::<Vec<_>>()
1280 };
1281
1282 let mut inline_trait_bounds: Vec<Loc<WhereClause>> = vec![];
1283
1284 let scope_type_params = scope_type_params
1285 .iter()
1286 .map(|param| {
1287 let hir::TypeParam {
1288 ident,
1289 name_id,
1290 trait_bounds,
1291 meta,
1292 } = ¶m.inner;
1293 if !trait_bounds.is_empty() {
1294 if let MetaType::Type = meta {
1295 inline_trait_bounds.push(
1296 WhereClause::Type {
1297 target: name_id.clone().at_loc(ident),
1298 traits: trait_bounds.clone(),
1299 }
1300 .at_loc(param),
1301 );
1302 } else {
1303 return Err(Diagnostic::bug(param, "Trait bounds on generic int")
1304 .primary_label("Trait bounds are only allowed on type parameters"));
1305 }
1306 }
1307 Ok((
1308 name_id.clone(),
1309 self.new_generic_with_meta(param.loc(), meta.clone()),
1310 ))
1311 })
1312 .collect::<Result<Vec<_>>>()?;
1313
1314 let new_list = type_params
1315 .iter()
1316 .enumerate()
1317 .map(|(i, param)| {
1318 let hir::TypeParam {
1319 ident,
1320 name_id,
1321 trait_bounds,
1322 meta,
1323 } = ¶m.inner;
1324
1325 let t = self.new_generic_with_meta(param.loc(), meta.clone());
1326
1327 if let Some(tf) = &turbofish_params[i] {
1328 let tf_ctx = turbofish.as_ref().unwrap();
1329 let ty = self.hir_type_expr_to_var(tf, tf_ctx.prev_generic_list)?;
1330 self.unify(&ty, &t, tf_ctx.type_ctx)
1331 .into_default_diagnostic(param, self)?;
1332 }
1333
1334 if !trait_bounds.is_empty() {
1335 if let MetaType::Type = meta {
1336 inline_trait_bounds.push(
1337 WhereClause::Type {
1338 target: name_id.clone().at_loc(ident),
1339 traits: trait_bounds.clone(),
1340 }
1341 .at_loc(param),
1342 );
1343 }
1344 Ok((name_id.clone(), t))
1345 } else {
1346 Ok((name_id.clone(), t))
1347 }
1348 })
1349 .collect::<Result<Vec<_>>>()?
1350 .into_iter()
1351 .chain(scope_type_params.into_iter())
1352 .map(|(name, t)| (name, t.clone()))
1353 .collect::<HashMap<_, _>>();
1354
1355 self.trace_stack.push(TraceStackEntry::NewGenericList(
1356 new_list
1357 .iter()
1358 .map(|(name, var)| (name.clone(), var.debug_resolve(self)))
1359 .collect(),
1360 ));
1361
1362 let token = self.add_mapped_generic_list(source, new_list.clone());
1363
1364 for constraint in where_clauses.iter().chain(inline_trait_bounds.iter()) {
1365 match &constraint.inner {
1366 WhereClause::Type { target, traits } => {
1367 self.visit_trait_bounds(target, traits.as_slice(), &token)?;
1368 }
1369 WhereClause::Int { target, constraint } => {
1370 let int_constraint = self.visit_const_generic(constraint, &token)?;
1371 let tvar = new_list.get(target).ok_or_else(|| {
1372 Diagnostic::error(
1373 target,
1374 format!("{target} is not a generic parameter on this unit"),
1375 )
1376 .primary_label("Not a generic parameter")
1377 })?;
1378
1379 self.add_constraint(
1380 tvar.clone(),
1381 int_constraint,
1382 constraint.loc(),
1383 &tvar,
1384 ConstraintSource::Where,
1385 );
1386 }
1387 }
1388 }
1389
1390 Ok(token)
1391 }
1392
1393 pub fn add_mapped_generic_list(
1395 &mut self,
1396 source: GenericListSource,
1397 mapping: HashMap<NameID, TypeVarID>,
1398 ) -> GenericListToken {
1399 let reference = match source {
1400 GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1401 GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1402 GenericListSource::ImplBlock { target, id } => {
1403 GenericListToken::ImplBlock(target.clone(), id)
1404 }
1405 GenericListSource::Expression(id) => GenericListToken::Expression(id),
1406 };
1407
1408 if self
1409 .generic_lists
1410 .insert(reference.clone(), mapping)
1411 .is_some()
1412 {
1413 panic!("A generic list already existed for {reference:?}");
1414 }
1415 reference
1416 }
1417
1418 pub fn remove_generic_list(&mut self, source: GenericListSource) {
1419 let reference = match source {
1420 GenericListSource::Anonymous => GenericListToken::Anonymous(self.generic_lists.len()),
1421 GenericListSource::Definition(name) => GenericListToken::Definition(name.clone()),
1422 GenericListSource::ImplBlock { target, id } => {
1423 GenericListToken::ImplBlock(target.clone(), id)
1424 }
1425 GenericListSource::Expression(id) => GenericListToken::Expression(id),
1426 };
1427
1428 self.generic_lists.remove(&reference.clone());
1429 }
1430
1431 #[tracing::instrument(level = "trace", skip_all)]
1432 #[trace_typechecker]
1433 pub fn visit_block(
1434 &mut self,
1435 block: &Block,
1436 ctx: &Context,
1437 generic_list: &GenericListToken,
1438 ) -> Result<()> {
1439 for statement in &block.statements {
1440 self.visit_statement(statement, ctx, generic_list);
1441 }
1442 if let Some(result) = &block.result {
1443 self.visit_expression(result, ctx, generic_list);
1444 }
1445 Ok(())
1446 }
1447
1448 #[tracing::instrument(level = "trace", skip_all)]
1449 pub fn visit_impl_blocks(&mut self, item_list: &ItemList) -> TraitImplList {
1450 let mut trait_impls = TraitImplList::new();
1451 for (target, impls) in &item_list.impls {
1452 for ((trait_name, type_expressions), impl_block) in impls {
1453 let result = (|| {
1454 let generic_list = self.create_generic_list(
1455 GenericListSource::ImplBlock {
1456 target,
1457 id: impl_block.id,
1458 },
1459 &[],
1460 impl_block.type_params.as_slice(),
1461 None,
1462 &[],
1463 )?;
1464
1465 let loc = trait_name
1466 .name_loc()
1467 .map(|n| ().at_loc(&n))
1468 .unwrap_or(().at_loc(&impl_block));
1469
1470 let trait_type_params = type_expressions
1471 .iter()
1472 .map(|param| {
1473 Ok(TemplateTypeVarID::new(self.hir_type_expr_to_var(
1474 ¶m.clone().at_loc(&loc),
1475 &generic_list,
1476 )?))
1477 })
1478 .collect::<Result<_>>()?;
1479
1480 let target_type_params = impl_block
1481 .target
1482 .type_params()
1483 .into_iter()
1484 .map(|param| {
1485 Ok(TemplateTypeVarID::new(self.hir_type_expr_to_var(
1486 ¶m.clone().at_loc(&loc),
1487 &generic_list,
1488 )?))
1489 })
1490 .collect::<Result<_>>()?;
1491
1492 trait_impls
1493 .inner
1494 .entry(target.clone())
1495 .or_default()
1496 .push(TraitImpl {
1497 name: trait_name.clone(),
1498 target_type_params,
1499 trait_type_params,
1500
1501 impl_block: impl_block.inner.clone(),
1502 });
1503
1504 Ok(())
1505 })();
1506
1507 match result {
1508 Ok(()) => {}
1509 Err(e) => self.diags.errors.push(e),
1510 }
1511 }
1512 }
1513
1514 trait_impls
1515 }
1516
/// Assigns a type to `pattern` (recursively including its sub-patterns) and records the
/// name bindings it introduces.
///
/// A fresh generic type is registered as the type of the pattern's id. It is then
/// constrained according to the pattern kind:
/// - integer literal: a generic number that must fit the literal (`FitsIntLiteral`);
/// - `bool` literal: `bool`;
/// - name: the type of the bound name. An equation is added for the name unless it was
///   pre-declared;
/// - tuple: a tuple of the sub-pattern types;
/// - array: an array of the common element type, with the pattern's element count as its
///   length;
/// - enum variant / struct: the variant's or struct's type, instantiated with a fresh
///   anonymous generic list. Each argument is unified with its parameter type.
///
/// # Errors
/// Returns a diagnostic for empty array patterns and for type mismatches (name bindings,
/// array elements, constructor arguments). Errors from sub-patterns are propagated.
///
/// # Panics
/// Unifications of the fresh pattern type with int/bool/tuple/enum types are `expect`ed to
/// succeed.
#[trace_typechecker]
pub fn visit_pattern(
    &mut self,
    pattern: &Loc<Pattern>,
    ctx: &Context,
    generic_list: &GenericListToken,
) -> Result<()> {
    let new_type = self.new_generic_type(pattern.loc());
    self.add_equation(TypedExpression::Id(pattern.inner.id), new_type);
    match &pattern.inner.kind {
        hir::PatternKind::Integer(val) => {
            let (num_t, _) = &self.new_generic_number(pattern.loc(), ctx);
            // The literal's value must be representable in whatever number type is inferred.
            self.add_requirement(Requirement::FitsIntLiteral {
                value: ConstantInt::Literal(val.clone()),
                target_type: num_t.clone().at_loc(pattern),
            });
            self.unify(pattern, num_t, ctx)
                .expect("Failed to unify new_generic with int");
        }
        hir::PatternKind::Bool(_) => {
            pattern
                .unify_with(&self.t_bool(pattern.loc(), ctx.symtab), self)
                .commit(self, ctx)
                .expect("Expected new_generic with boolean");
        }
        hir::PatternKind::Name { name, pre_declared } => {
            // Pre-declared names already have an equation; otherwise the name takes the
            // pattern's type.
            if !pre_declared {
                self.add_equation(
                    TypedExpression::Name(name.clone().inner),
                    pattern.get_type(self),
                );
            }
            self.unify(
                &TypedExpression::Id(pattern.id),
                &TypedExpression::Name(name.clone().inner),
                ctx,
            )
            .into_default_diagnostic(name.loc(), self)?;
        }
        hir::PatternKind::Tuple(subpatterns) => {
            for pattern in subpatterns {
                self.visit_pattern(pattern, ctx, generic_list)?;
            }
            let tuple_type = self.add_type_var(TypeVar::tuple(
                pattern.loc(),
                subpatterns
                    .iter()
                    .map(|pattern| {
                        let p_type = pattern.get_type(self);
                        Ok(p_type)
                    })
                    .collect::<Result<_>>()?,
            ));

            self.unify(pattern, &tuple_type, ctx)
                .expect("Unification of new_generic with tuple type cannot fail");
        }
        hir::PatternKind::Array(inner) => {
            for pattern in inner {
                self.visit_pattern(pattern, ctx, generic_list)?;
            }
            if inner.len() == 0 {
                return Err(
                    Diagnostic::error(pattern, "Empty array patterns are unsupported")
                        .primary_label("Empty array pattern"),
                );
            } else {
                // All elements must share the first element's type.
                let inner_t = inner[0].get_type(self);

                for pattern in inner.iter().skip(1) {
                    self.unify(pattern, &inner_t, ctx)
                        .into_default_diagnostic(pattern, self)?;
                }

                // Array length is the number of element patterns.
                pattern
                    .unify_with(
                        &TypeVar::Known(
                            pattern.loc(),
                            KnownType::Array,
                            vec![
                                inner_t,
                                self.add_type_var(TypeVar::Known(
                                    pattern.loc(),
                                    KnownType::Integer(inner.len().to_bigint()),
                                    vec![],
                                )),
                            ],
                        )
                        .insert(self),
                        self,
                    )
                    .commit(self, ctx)
                    .into_default_diagnostic(pattern, self)?;
            }
        }
        hir::PatternKind::Type(name, args) => {
            // Instantiate the constructor's generics with a fresh anonymous list; this
            // list (not the caller's) is used for the sub-patterns below.
            let (condition_type, params, generic_list) =
                match ctx.symtab.patternable_type_by_id(name).inner {
                    Patternable {
                        kind: PatternableKind::Enum,
                        params: _,
                    } => {
                        let enum_variant = ctx.symtab.enum_variant_by_id(name).inner;
                        let generic_list = self.create_generic_list(
                            GenericListSource::Anonymous,
                            &enum_variant.type_params,
                            &[],
                            None,
                            &[],
                        )?;

                        let condition_type = self.type_var_from_hir(
                            pattern.loc(),
                            &enum_variant.output_type,
                            &generic_list,
                        )?;

                        (condition_type, enum_variant.params, generic_list)
                    }
                    Patternable {
                        kind: PatternableKind::Struct,
                        params: _,
                    } => {
                        let s = ctx.symtab.struct_by_id(name).inner;
                        let generic_list = self.create_generic_list(
                            GenericListSource::Anonymous,
                            &s.type_params,
                            &[],
                            None,
                            &[],
                        )?;

                        let condition_type =
                            self.type_var_from_hir(pattern.loc(), &s.self_type, &generic_list)?;

                        (condition_type, s.params, generic_list)
                    }
                };

            self.unify(pattern, &condition_type, ctx)
                .expect("Unification of new_generic with enum cannot fail");

            // NOTE(review): `zip` stops at the shorter of args/params; this assumes the
            // argument count was already checked before type checking.
            for (
                PatternArgument {
                    target,
                    value: pattern,
                    kind,
                },
                Parameter {
                    name: _,
                    ty: target_type,
                    no_mangle: _,
                    field_translator: _,
                },
            ) in args.iter().zip(params.0.iter())
            {
                self.visit_pattern(pattern, ctx, &generic_list)?;
                let target_type =
                    self.type_var_from_hir(target_type.loc(), target_type, &generic_list)?;

                // Short-named arguments are reported at the field name, others at the
                // sub-pattern.
                let loc = match kind {
                    hir::ArgumentKind::Positional => pattern.loc(),
                    hir::ArgumentKind::Named => pattern.loc(),
                    hir::ArgumentKind::ShortNamed => target.loc(),
                };

                self.unify(pattern, &target_type, ctx).into_diagnostic(
                    loc,
                    |d, tm| {
                        let (expected, got) = tm.display_e_g(self);
                        d.message(format!(
                            "Argument type mismatch. Expected {expected} got {got}"
                        ))
                        .primary_label(format!("expected {expected}"))
                    },
                    self,
                )?;
            }
        }
    }
    Ok(())
}
1699
1700 #[trace_typechecker]
1701 pub fn visit_wal_trace(
1702 &mut self,
1703 trace: &Loc<WalTrace>,
1704 ctx: &Context,
1705 generic_list: &GenericListToken,
1706 ) -> Result<()> {
1707 let WalTrace { clk, rst } = &trace.inner;
1708 clk.as_ref()
1709 .map(|x| {
1710 self.visit_expression(x, ctx, generic_list);
1711 x.unify_with(&self.t_clock(trace.loc(), ctx.symtab), self)
1712 .commit(self, ctx)
1713 .into_default_diagnostic(x, self)
1714 })
1715 .transpose()?;
1716 rst.as_ref()
1717 .map(|x| {
1718 self.visit_expression(x, ctx, generic_list);
1719 x.unify_with(&self.t_bool(trace.loc(), ctx.symtab), self)
1720 .commit(self, ctx)
1721 .into_default_diagnostic(x, self)
1722 })
1723 .transpose()?;
1724 Ok(())
1725 }
1726
/// Type checks a single statement, returning the first hard error.
///
/// Per statement kind:
/// - `Error`: inside a pipeline, the current stage depth is unified with the error type.
/// - `Binding`: the value and pattern are visited and unified, together with an optional
///   type annotation and WAL trace. Most failures here go through `handle_in(&mut
///   self.diags)` instead of being returned, so checking continues.
/// - `Register`: delegated to `visit_register`.
/// - `Declaration`: each declared name gets a fresh generic type.
/// - `PipelineRegMarker`: an optional condition must be `bool`. The pipeline's current
///   stage depth is then replaced by a new type-level integer constrained to
///   `current + offset`, where `offset` is the explicit count or 1.
/// - `Label`: the label name is unified with the current stage depth, reusing an existing
///   equation for the name if there is one.
/// - `Assert`: the expression must be `bool`.
/// - `Set`: the target must be the backward type of the value's type.
///
/// # Errors
/// Propagates errors from register checking, pipeline-marker conditions and counts,
/// binding type annotations, and missing pipeline state.
///
/// # Panics
/// Unifications in the `Error` and `Label` arms are `unwrap`ped.
#[trace_typechecker]
pub fn visit_statement_error(
    &mut self,
    stmt: &Loc<Statement>,
    ctx: &Context,
    generic_list: &GenericListToken,
) -> Result<()> {
    match &stmt.inner {
        Statement::Error => {
            // Poison the stage depth so that later stage arithmetic does not produce
            // follow-up errors.
            if let Some(current_stage_depth) =
                self.pipeline_state.as_ref().map(|s| s.current_stage_depth)
            {
                current_stage_depth
                    .unify_with(&self.t_err(stmt.loc()), self)
                    .commit(self, ctx)
                    .unwrap();
            }
            Ok(())
        }
        Statement::Binding(Binding {
            pattern,
            ty,
            value,
            wal_trace,
        }) => {
            trace!("Visiting `let {} = ..`", pattern.kind);
            self.visit_expression(value, ctx, generic_list);

            self.visit_pattern(pattern, ctx, generic_list)
                .handle_in(&mut self.diags);

            // The mismatch diagnostic points at the annotation if present, else the value.
            self.unify(&TypedExpression::Id(pattern.id), value, ctx)
                .into_diagnostic(
                    pattern.loc(),
                    error_pattern_type_mismatch(
                        ty.as_ref().map(|t| t.loc()).unwrap_or_else(|| value.loc()),
                        self,
                    ),
                    self,
                )
                .handle_in(&mut self.diags);

            if let Some(t) = ty {
                let tvar = self.type_var_from_hir(t.loc(), t, generic_list)?;
                self.unify(&TypedExpression::Id(pattern.id), &tvar, ctx)
                    .into_default_diagnostic(value.loc(), self)
                    .handle_in(&mut self.diags);
            }

            wal_trace
                .as_ref()
                .map(|wt| self.visit_wal_trace(wt, ctx, generic_list))
                .transpose()
                .handle_in(&mut self.diags);

            Ok(())
        }
        Statement::Expression(expr) => {
            self.visit_expression(expr, ctx, generic_list);
            Ok(())
        }
        Statement::Register(reg) => self.visit_register(reg, ctx, generic_list),
        Statement::Declaration(names) => {
            for name in names {
                let new_type = self.new_generic_type(name.loc());
                self.add_equation(TypedExpression::Name(name.clone().inner), new_type);
            }
            Ok(())
        }
        Statement::PipelineRegMarker(extra) => {
            match extra {
                Some(PipelineRegMarkerExtra::Condition(cond)) => {
                    self.visit_expression(cond, ctx, generic_list);
                    cond.unify_with(&self.t_bool(cond.loc(), ctx.symtab), self)
                        .commit(self, ctx)
                        .into_default_diagnostic(cond, self)?;
                }
                Some(PipelineRegMarkerExtra::Count {
                    count: _,
                    count_typeexpr_id: _,
                }) => {}
                None => {}
            }

            let current_stage_depth = self
                .pipeline_state
                .clone()
                .ok_or_else(|| {
                    diag_anyhow!(stmt, "Found a pipeline reg marker in a non-pipeline")
                })?
                .current_stage_depth;

            let new_depth = self.new_generic_tlint(stmt.loc());
            // Number of stages this marker advances: the explicit count, or 1.
            let offset = match extra {
                Some(PipelineRegMarkerExtra::Count {
                    count,
                    count_typeexpr_id,
                }) => {
                    let var = self.hir_type_expr_to_var(count, generic_list)?;
                    self.add_equation(TypedExpression::Id(*count_typeexpr_id), var.clone());
                    var
                }
                Some(PipelineRegMarkerExtra::Condition(_)) | None => self.add_type_var(
                    TypeVar::Known(stmt.loc(), KnownType::Integer(1.to_bigint()), vec![]),
                ),
            };

            let total_depth = ConstraintExpr::Sum(
                Box::new(ConstraintExpr::Var(offset)),
                Box::new(ConstraintExpr::Var(current_stage_depth)),
            );
            self.pipeline_state
                .as_mut()
                .expect("Expected to have a pipeline state")
                .current_stage_depth = new_depth.clone();

            // new_depth == offset + previous depth
            self.add_constraint(
                new_depth.clone(),
                total_depth,
                stmt.loc(),
                &new_depth,
                ConstraintSource::PipelineRegCount {
                    reg: stmt.loc(),
                    total: self.get_pipeline_state(stmt)?.total_depth.loc(),
                },
            );

            Ok(())
        }
        Statement::Label(name) => {
            // A label may already have an equation (e.g. from an earlier reference), in
            // which case the existing variable is reused.
            let key = TypedExpression::Name(name.inner.clone());
            let var = if !self.equations.contains_key(&key) {
                let var = self.new_generic_tlint(name.loc());
                self.trace_stack.push(TraceStackEntry::AddingPipelineLabel(
                    name.inner.clone(),
                    var.debug_resolve(self),
                ));
                self.add_equation(key.clone(), var.clone());
                var
            } else {
                let var = self.equations.get(&key).unwrap().clone();
                self.trace_stack
                    .push(TraceStackEntry::RecoveringPipelineLabel(
                        name.inner.clone(),
                        var.debug_resolve(self),
                    ));
                var
            };
            self.unify(
                &var,
                &self.get_pipeline_state(name)?.current_stage_depth.clone(),
                ctx,
            )
            .unwrap();
            Ok(())
        }
        Statement::WalSuffixed { .. } => Ok(()),
        Statement::Assert(expr) => {
            self.visit_expression(expr, ctx, generic_list);

            expr.unify_with(&self.t_bool(stmt.loc(), ctx.symtab), self)
                .commit(self, ctx)
                .into_default_diagnostic(expr, self)
                .handle_in(&mut self.diags);
            Ok(())
        }
        Statement::Set { target, value } => {
            self.visit_expression(target, ctx, generic_list);
            self.visit_expression(value, ctx, generic_list);

            // `target` must be the backward counterpart of the value's type.
            let inner_type = self.new_generic_type(value.loc());
            let outer_type =
                TypeVar::backward(stmt.loc(), inner_type.clone(), self).insert(self);
            self.unify_expression_generic_error(target, &outer_type, ctx)
                .handle_in(&mut self.diags);
            self.unify_expression_generic_error(value, &inner_type, ctx)
                .handle_in(&mut self.diags);

            Ok(())
        }
    }
}
1910
1911 pub fn visit_statement(
1912 &mut self,
1913 stmt: &Loc<Statement>,
1914 ctx: &Context,
1915 generic_list: &GenericListToken,
1916 ) {
1917 if let Err(e) = self.visit_statement_error(stmt, ctx, generic_list) {
1918 self.diags.errors.push(e);
1919 }
1920 }
1921
1922 #[trace_typechecker]
1923 pub fn visit_register(
1924 &mut self,
1925 reg: &Register,
1926 ctx: &Context,
1927 generic_list: &GenericListToken,
1928 ) -> Result<()> {
1929 self.visit_pattern(®.pattern, ctx, generic_list)?;
1930
1931 let type_spec_type = match ®.value_type {
1932 Some(t) => Some(self.type_var_from_hir(t.loc(), t, generic_list)?.at_loc(t)),
1933 None => None,
1934 };
1935
1936 if let Some(tvar) = &type_spec_type {
1939 self.unify(&TypedExpression::Id(reg.pattern.id), tvar, ctx)
1940 .into_diagnostic_no_expected_source(
1941 reg.pattern.loc(),
1942 error_pattern_type_mismatch(tvar.loc(), self),
1943 self,
1944 )?;
1945 }
1946
1947 self.visit_expression(®.clock, ctx, generic_list);
1948 self.visit_expression(®.value, ctx, generic_list);
1949
1950 if let Some(tvar) = &type_spec_type {
1951 self.unify(®.value, tvar, ctx)
1952 .into_default_diagnostic(reg.value.loc(), self)?;
1953 }
1954
1955 if let Some((rst_cond, rst_value)) = ®.reset {
1956 self.visit_expression(rst_cond, ctx, generic_list);
1957 self.visit_expression(rst_value, ctx, generic_list);
1958 rst_cond
1960 .unify_with(&self.t_bool(rst_cond.loc(), ctx.symtab), self)
1961 .commit(self, ctx)
1962 .into_diagnostic(
1963 rst_cond.loc(),
1964 |diag,
1965 Tm {
1966 g: got,
1967 e: _expected,
1968 }| {
1969 diag.message(format!(
1970 "Register reset condition must be a bool, got {got}",
1971 got = got.display(self)
1972 ))
1973 .primary_label("expected bool")
1974 },
1975 self,
1976 )?;
1977
1978 self.unify(&rst_value.inner, ®.value.inner, ctx)
1980 .into_diagnostic(
1981 rst_value.loc(),
1982 |diag, tm| {
1983 let (expected, got) = tm.display_e_g(self);
1984 diag.message(format!(
1985 "Register reset value mismatch. Expected {expected} got {got}"
1986 ))
1987 .primary_label(format!("expected {expected}"))
1988 .secondary_label(®.pattern, format!("because this has type {expected}"))
1989 },
1990 self,
1991 )?;
1992 }
1993
1994 if let Some(initial) = ®.initial {
1995 self.visit_expression(initial, ctx, generic_list);
1996
1997 self.unify(&initial.inner, ®.value.inner, ctx)
1998 .into_diagnostic(
1999 initial.loc(),
2000 |diag, tm| {
2001 let (expected, got) = tm.display_e_g(self);
2002 diag.message(format!(
2003 "Register initial value mismatch. Expected {expected} got {got}"
2004 ))
2005 .primary_label(format!("expected {expected}, got {got}"))
2006 .secondary_label(®.pattern, format!("because this has type {got}"))
2007 },
2008 self,
2009 )?;
2010 }
2011
2012 reg.clock
2013 .unify_with(&self.t_clock(reg.clock.loc(), ctx.symtab), self)
2014 .commit(self, ctx)
2015 .into_diagnostic(
2016 reg.clock.loc(),
2017 |diag,
2018 Tm {
2019 g: got,
2020 e: _expected,
2021 }| {
2022 diag.message(format!(
2023 "Expected clock, got {got}",
2024 got = got.display(self)
2025 ))
2026 .primary_label("expected clock")
2027 },
2028 self,
2029 )?;
2030
2031 self.unify(&TypedExpression::Id(reg.pattern.id), ®.value, ctx)
2032 .into_diagnostic(
2033 reg.pattern.loc(),
2034 error_pattern_type_mismatch(reg.value.loc(), self),
2035 self,
2036 )?;
2037
2038 Ok(())
2039 }
2040
2041 #[trace_typechecker]
2042 pub fn visit_trait_spec(
2043 &mut self,
2044 trait_spec: &Loc<TraitSpec>,
2045 generic_list: &GenericListToken,
2046 ) -> Result<Loc<TraitReq>> {
2047 let type_params = if let Some(type_params) = &trait_spec.inner.type_params {
2048 type_params
2049 .inner
2050 .iter()
2051 .map(|te| self.hir_type_expr_to_var(te, generic_list))
2052 .collect::<Result<_>>()?
2053 } else {
2054 vec![]
2055 };
2056
2057 Ok(TraitReq {
2058 name: trait_spec.name.clone(),
2059 type_params,
2060 }
2061 .at_loc(trait_spec))
2062 }
2063
/// Adds the trait bounds `traits` to the generic parameter `target` in the generic list
/// identified by `generic_list_tok`.
///
/// The specs are lowered to `TraitReq`s and collected through a `BTreeSet`, which removes
/// duplicates and orders them. If any remain, the parameter's type variable is replaced in
/// the generic list by a new unknown variable. The new variable has the same location and
/// id, and its trait list is the old list extended with the new requirements.
///
/// # Errors
/// Propagates errors from lowering the trait specs. Returns a bug diagnostic if `target`
/// is missing from the generic list, or if its variable does not resolve to an unknown
/// variable of meta type `Type`.
///
/// # Panics
/// Panics if `generic_list_tok` does not name a registered generic list.
#[trace_typechecker]
pub fn visit_trait_bounds(
    &mut self,
    target: &Loc<NameID>,
    traits: &[Loc<TraitSpec>],
    generic_list_tok: &GenericListToken,
) -> Result<()> {
    let trait_reqs = traits
        .iter()
        .map(|spec| self.visit_trait_spec(spec, generic_list_tok))
        .collect::<Result<BTreeSet<_>>>()?
        .into_iter()
        .collect_vec();

    if !trait_reqs.is_empty() {
        let trait_list = TraitList::from_vec(trait_reqs);

        let generic_list = self.generic_lists.get(generic_list_tok).unwrap();

        let Some(tvar) = generic_list.get(&target.inner) else {
            return Err(Diagnostic::bug(
                target,
                "Couldn't find generic from where clause in generic list",
            )
            .primary_label(format!(
                "Generic type {} not found in generic list",
                target.inner
            )));
        };

        self.trace_stack.push(TraceStackEntry::AddingTraitBounds(
            tvar.debug_resolve(self),
            trait_list.clone(),
        ));

        // Only still-unknown type variables can take bounds; anything else means an
        // earlier stage let an invalid bound through.
        let TypeVar::Unknown(loc, id, old_trait_list, MetaType::Type) = tvar.resolve(self)
        else {
            return Err(Diagnostic::bug(
                target,
                "Trait bounds on known type or type-level integer",
            )
            .primary_label(format!(
                "Trait bounds on {}, which should've been caught in ast-lowering",
                target.inner
            )));
        };

        // Same location and id, extended trait list.
        let new_tvar = self.add_type_var(TypeVar::Unknown(
            *loc,
            *id,
            old_trait_list.clone().extend(trait_list),
            MetaType::Type,
        ));

        trace!(
            "Adding trait bound {} on type {}",
            new_tvar.display_with_meta(true, self),
            target.inner
        );

        // Re-fetch mutably: the earlier shared borrow of the list has ended.
        let generic_list = self.generic_lists.get_mut(generic_list_tok).unwrap();
        generic_list.insert(target.inner.clone(), new_tvar);
    }

    Ok(())
}
2130
2131 pub fn visit_const_generic_with_id(
2132 &mut self,
2133 gen: &Loc<ConstGenericWithId>,
2134 generic_list_token: &GenericListToken,
2135 constraint_source: ConstraintSource,
2136 ctx: &Context,
2137 ) -> Result<TypeVarID> {
2138 let var = match &gen.inner.inner {
2139 ConstGeneric::Name(name) => {
2140 let ty = &ctx.symtab.type_symbol_by_id(&name);
2141 match &ty.inner {
2142 TypeSymbol::Declared(_, _) => {
2143 return Err(Diagnostic::error(
2144 name,
2145 "{type_decl_kind} cannot be used in a const generic expression",
2146 )
2147 .primary_label("Type in const generic")
2148 .secondary_label(ty, "{name} is declared here"))
2149 }
2150 TypeSymbol::GenericArg { .. } | TypeSymbol::GenericMeta(MetaType::Type) => {
2151 return Err(Diagnostic::error(
2152 name,
2153 "Generic types cannot be used in const generic expressions",
2154 )
2155 .primary_label("Type in const generic")
2156 .secondary_label(ty, "{name} is declared here")
2157 .span_suggest_insert_before(
2158 "Consider making this a value",
2159 ty.loc(),
2160 "#uint ",
2161 ))
2162 }
2163 TypeSymbol::GenericMeta(MetaType::Number) => {
2164 self.new_generic_tlnumber(gen.loc())
2165 }
2166 TypeSymbol::GenericMeta(MetaType::Int) => self.new_generic_tlint(gen.loc()),
2167 TypeSymbol::GenericMeta(MetaType::Uint) => self.new_generic_tluint(gen.loc()),
2168 TypeSymbol::GenericMeta(MetaType::Bool) => self.new_generic_tlbool(gen.loc()),
2169 TypeSymbol::GenericMeta(MetaType::Any) => {
2170 diag_bail!(gen, "Found any meta type")
2171 }
2172 TypeSymbol::Alias(_) => {
2173 return Err(Diagnostic::error(
2174 gen,
2175 "Aliases are not currently supported in const generics",
2176 )
2177 .secondary_label(ty, "Alias defined here"))
2178 }
2179 }
2180 }
2181 ConstGeneric::Const(_)
2182 | ConstGeneric::Add(_, _)
2183 | ConstGeneric::Sub(_, _)
2184 | ConstGeneric::Mul(_, _)
2185 | ConstGeneric::Div(_, _)
2186 | ConstGeneric::Mod(_, _)
2187 | ConstGeneric::UintBitsToFit(_) => self.new_generic_tlnumber(gen.loc()),
2188 ConstGeneric::Eq(_, _) | ConstGeneric::NotEq(_, _) => {
2189 self.new_generic_tlbool(gen.loc())
2190 }
2191 };
2192 let constraint = self.visit_const_generic(&gen.inner.inner, generic_list_token)?;
2193 self.add_equation(TypedExpression::Id(gen.id), var.clone());
2194 self.add_constraint(var.clone(), constraint, gen.loc(), &var, constraint_source);
2195 Ok(var)
2196 }
2197
2198 #[trace_typechecker]
2199 pub fn visit_const_generic(
2200 &self,
2201 constraint: &ConstGeneric,
2202 generic_list: &GenericListToken,
2203 ) -> Result<ConstraintExpr> {
2204 let wrap = |lhs,
2205 rhs,
2206 wrapper: fn(Box<ConstraintExpr>, Box<ConstraintExpr>) -> ConstraintExpr|
2207 -> Result<_> {
2208 Ok(wrapper(
2209 Box::new(self.visit_const_generic(lhs, generic_list)?),
2210 Box::new(self.visit_const_generic(rhs, generic_list)?),
2211 ))
2212 };
2213 let constraint = match constraint {
2214 ConstGeneric::Name(n) => {
2215 let var = self
2216 .get_generic_list(generic_list)
2217 .ok_or_else(|| diag_anyhow!(n, "Found no generic list"))?
2218 .get(n)
2219 .ok_or_else(|| {
2220 Diagnostic::bug(n, "Found non-generic argument in where clause")
2221 })?;
2222 ConstraintExpr::Var(*var)
2223 }
2224 ConstGeneric::Const(val) => ConstraintExpr::Integer(val.clone()),
2225 ConstGeneric::Add(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Sum)?,
2226 ConstGeneric::Sub(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Difference)?,
2227 ConstGeneric::Mul(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Product)?,
2228 ConstGeneric::Div(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Div)?,
2229 ConstGeneric::Mod(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Mod)?,
2230 ConstGeneric::Eq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::Eq)?,
2231 ConstGeneric::NotEq(lhs, rhs) => wrap(lhs, rhs, ConstraintExpr::NotEq)?,
2232 ConstGeneric::UintBitsToFit(a) => ConstraintExpr::UintBitsToRepresent(Box::new(
2233 self.visit_const_generic(a, generic_list)?,
2234 )),
2235 };
2236 Ok(constraint)
2237 }
2238}
2239
2240impl TypeState {
2242 fn new_typeid(&self) -> u64 {
2243 let mut next = self.next_typeid.borrow_mut();
2244 let result = *next;
2245 *next += 1;
2246 result
2247 }
2248
2249 pub fn add_equation(&mut self, expression: TypedExpression, var: TypeVarID) {
2250 self.trace_stack.push(TraceStackEntry::AddingEquation(
2251 expression.clone(),
2252 var.debug_resolve(self),
2253 ));
2254 if let Some(prev) = self.equations.insert(expression.clone(), var.clone()) {
2255 let var = var.clone();
2256 let expr = expression.clone();
2257 println!("{}", format_trace_stack(self));
2258 panic!("Adding equation for {} == {} but a previous eq exists.\n\tIt was previously bound to {}", expr, var.debug_resolve(self), prev.debug_resolve(self))
2259 }
2260 }
2261
2262 fn add_constraint(
2263 &mut self,
2264 lhs: TypeVarID,
2265 rhs: ConstraintExpr,
2266 loc: Loc<()>,
2267 inside: &TypeVarID,
2268 source: ConstraintSource,
2269 ) {
2270 let replaces = lhs.clone();
2271 let rhs = rhs.with_context(&replaces, &inside, source).at_loc(&loc);
2272
2273 self.trace_stack.push(TraceStackEntry::AddingConstraint(
2274 lhs.debug_resolve(self),
2275 rhs.inner.clone(),
2276 ));
2277
2278 self.constraints.add_int_constraint(lhs, rhs);
2279 }
2280
2281 fn add_requirement(&mut self, requirement: Requirement) {
2282 self.trace_stack
2283 .push(TraceStackEntry::AddRequirement(requirement.clone()));
2284 self.requirements.push(requirement)
2285 }
2286
2287 fn unify_inner(
2292 &mut self,
2293 e1: &impl HasType,
2294 e2: &impl HasType,
2295 ctx: &Context,
2296 ) -> std::result::Result<TypeVarID, UnificationError> {
2297 let v1 = e1.get_type(self);
2298 let v2 = e2.get_type(self);
2299
2300 trace!(
2301 "Unifying {} with {}",
2302 v1.debug_resolve(self),
2303 v2.debug_resolve(self)
2304 );
2305
2306 self.trace_stack.push(TraceStackEntry::TryingUnify(
2307 v1.debug_resolve(self),
2308 v2.debug_resolve(self),
2309 ));
2310
2311 macro_rules! err_producer {
2312 () => {{
2313 self.trace_stack
2314 .push(TraceStackEntry::Message("Produced error".to_string()));
2315 UnificationError::Normal(Tm {
2316 g: UnificationTrace::new(v1),
2317 e: UnificationTrace::new(v2),
2318 })
2319 }};
2320 }
2321 macro_rules! meta_err_producer {
2322 () => {{
2323 self.trace_stack
2324 .push(TraceStackEntry::Message("Produced error".to_string()));
2325 UnificationError::MetaMismatch(Tm {
2326 g: UnificationTrace::new(v1),
2327 e: UnificationTrace::new(v2),
2328 })
2329 }};
2330 }
2331
2332 macro_rules! unify_if {
2333 ($condition:expr, $new_type:expr, $replaced_type:expr) => {
2334 if $condition {
2335 Ok(($new_type, $replaced_type))
2336 } else {
2337 Err(err_producer!())
2338 }
2339 };
2340 }
2341
2342 let unify_params = |s: &mut Self,
2343 p1: &[TypeVarID],
2344 p2: &[TypeVarID]|
2345 -> std::result::Result<(), UnificationError> {
2346 if p1.len() != p2.len() {
2347 return Err({
2348 s.trace_stack
2349 .push(TraceStackEntry::Message("Produced error".to_string()));
2350 UnificationError::Normal(Tm {
2351 g: UnificationTrace::new(v1),
2352 e: UnificationTrace::new(v2),
2353 })
2354 });
2355 }
2356
2357 for (t1, t2) in p1.iter().zip(p2.iter()) {
2358 match s.unify_inner(t1, t2, ctx) {
2359 Ok(result) => result,
2360 Err(e) => {
2361 s.trace_stack
2362 .push(TraceStackEntry::Message("Adding context".to_string()));
2363 return Err(e).add_context(v1.clone(), v2.clone());
2364 }
2365 };
2366 }
2367 Ok(())
2368 };
2369
2370 let result = match (
2373 &(v1, v1.resolve(self).clone()),
2374 &(v2, v2.resolve(self).clone()),
2375 ) {
2376 ((_, TypeVar::Known(_, KnownType::Error, _)), _) => Ok((v1, vec![v2])),
2377 (_, (_, TypeVar::Known(_, KnownType::Error, _))) => Ok((v2, vec![v1])),
2378 ((_, TypeVar::Known(_, t1, p1)), (_, TypeVar::Known(_, t2, p2))) => {
2379 match (t1, t2) {
2380 (KnownType::Integer(val1), KnownType::Integer(val2)) => {
2381 unify_if!(val1 == val2, v1, vec![])
2386 }
2387 (KnownType::Named(n1), KnownType::Named(n2)) => {
2388 match (
2389 &ctx.symtab.type_symbol_by_id(n1).inner,
2390 &ctx.symtab.type_symbol_by_id(n2).inner,
2391 ) {
2392 (TypeSymbol::Declared(_, _), TypeSymbol::Declared(_, _)) => {
2393 if n1 != n2 {
2394 return Err(err_producer!());
2395 }
2396
2397 let new_ts1 = ctx.symtab.type_symbol_by_id(n1).inner;
2398 let new_ts2 = ctx.symtab.type_symbol_by_id(n2).inner;
2399 unify_params(self, &p1, &p2)?;
2400 unify_if!(new_ts1 == new_ts2, v1, vec![])
2401 }
2402 (TypeSymbol::Declared(_, _), TypeSymbol::GenericArg { traits }) => {
2403 if !traits.is_empty() {
2404 todo!("Implement trait unifictaion");
2405 }
2406 Ok((v1, vec![]))
2407 }
2408 (TypeSymbol::GenericArg { traits }, TypeSymbol::Declared(_, _)) => {
2409 if !traits.is_empty() {
2410 todo!("Implement trait unifictaion");
2411 }
2412 Ok((v2, vec![]))
2413 }
2414 (
2415 TypeSymbol::GenericArg { traits: ltraits },
2416 TypeSymbol::GenericArg { traits: rtraits },
2417 ) => {
2418 if !ltraits.is_empty() || !rtraits.is_empty() {
2419 todo!("Implement trait unifictaion");
2420 }
2421 Ok((v1, vec![]))
2422 }
2423 (TypeSymbol::Declared(_, _), TypeSymbol::GenericMeta(_)) => todo!(),
2424 (TypeSymbol::GenericArg { traits: _ }, TypeSymbol::GenericMeta(_)) => {
2425 todo!()
2426 }
2427 (TypeSymbol::GenericMeta(_), TypeSymbol::Declared(_, _)) => todo!(),
2428 (TypeSymbol::GenericMeta(_), TypeSymbol::GenericArg { traits: _ }) => {
2429 todo!()
2430 }
2431 (TypeSymbol::Alias(_), _) | (_, TypeSymbol::Alias(_)) => {
2432 return Err(UnificationError::Specific(Diagnostic::bug(
2433 ().nowhere(),
2434 "Encountered a raw type alias during unification",
2435 )))
2436 }
2437 (TypeSymbol::GenericMeta(_), TypeSymbol::GenericMeta(_)) => todo!(),
2438 }
2439 }
2440 (KnownType::Array, KnownType::Array)
2441 | (KnownType::Tuple, KnownType::Tuple)
2442 | (KnownType::Wire, KnownType::Wire)
2443 | (KnownType::Inverted, KnownType::Inverted) => {
2444 unify_params(self, &p1, &p2)?;
2448 Ok((v1, vec![]))
2449 }
2450 (_, _) => Err(err_producer!()),
2451 }
2452 }
2453 (
2455 (_, TypeVar::Unknown(loc1, _, traits1, meta1)),
2456 (_, TypeVar::Unknown(loc2, _, traits2, meta2)),
2457 ) => {
2458 let new_loc = if meta1.is_more_concrete_than(meta2) {
2459 loc1
2460 } else {
2461 loc2
2462 };
2463 let new_t = match unify_meta(meta1, meta2) {
2464 Some(meta @ MetaType::Any) | Some(meta @ MetaType::Number) => {
2465 if traits1.inner.is_empty() || traits2.inner.is_empty() {
2466 return Err(UnificationError::Specific(diag_anyhow!(
2467 new_loc,
2468 "Inferred an any meta-type with traits"
2469 )));
2470 }
2471 self.new_generic_with_meta(*loc1, meta)
2472 }
2473 Some(MetaType::Type) => {
2474 let new_trait_names = traits1
2475 .inner
2476 .iter()
2477 .chain(traits2.inner.iter())
2478 .map(|t| t.name.clone())
2479 .collect::<BTreeSet<_>>()
2480 .into_iter()
2481 .collect::<Vec<_>>();
2482
2483 let new_traits = new_trait_names
2484 .iter()
2485 .map(
2486 |name| match (traits1.get_trait(name), traits2.get_trait(name)) {
2487 (Some(req1), Some(req2)) => {
2488 let new_params = req1
2489 .inner
2490 .type_params
2491 .iter()
2492 .zip(req2.inner.type_params.iter())
2493 .map(|(p1, p2)| self.unify(p1, p2, ctx))
2494 .collect::<std::result::Result<_, UnificationError>>(
2495 )?;
2496
2497 Ok(TraitReq {
2498 name: name.clone(),
2499 type_params: new_params,
2500 }
2501 .at_loc(req1))
2502 }
2503 (Some(t), None) => Ok(t.clone()),
2504 (None, Some(t)) => Ok(t.clone()),
2505 (None, None) => panic!("Found a trait but neither side has it"),
2506 },
2507 )
2508 .collect::<std::result::Result<Vec<_>, UnificationError>>()?;
2509
2510 self.new_generic_with_traits(*new_loc, TraitList::from_vec(new_traits))
2511 }
2512 Some(MetaType::Int) => self.new_generic_tlint(*new_loc),
2513 Some(MetaType::Uint) => self.new_generic_tluint(*new_loc),
2514 Some(MetaType::Bool) => self.new_generic_tlbool(*new_loc),
2515 None => return Err(meta_err_producer!()),
2516 };
2517 Ok((new_t, vec![v1, v2]))
2518 }
2519 (
2520 (otherid, TypeVar::Known(loc, base, params)),
2521 (ukid, TypeVar::Unknown(ukloc, _, traits, meta)),
2522 )
2523 | (
2524 (ukid, TypeVar::Unknown(ukloc, _, traits, meta)),
2525 (otherid, TypeVar::Known(loc, base, params)),
2526 ) => {
2527 let trait_is_expected = match (&v1.resolve(self), &v2.resolve(self)) {
2528 (TypeVar::Known(_, _, _), _) => true,
2529 _ => false,
2530 };
2531
2532 let impls = self.ensure_impls(otherid, traits, trait_is_expected, ukloc, ctx)?;
2533
2534 self.trace_stack.push(TraceStackEntry::Message(
2535 "Unifying trait_parameters".to_string(),
2536 ));
2537 let mut new_params = params.clone();
2538 for (trait_impl, trait_req) in impls {
2539 let mut param_map = BTreeMap::new();
2540
2541 for (l, r) in trait_req
2542 .type_params
2543 .iter()
2544 .zip(trait_impl.trait_type_params)
2545 {
2546 let copy = r.make_copy_with_mapping(self, &mut param_map);
2547 self.unify(l, ©, ctx)?;
2548 }
2549
2550 new_params = trait_impl
2551 .target_type_params
2552 .iter()
2553 .zip(new_params)
2554 .map(|(l, r)| {
2555 let copy = l.make_copy_with_mapping(self, &mut param_map);
2556 self.unify(©, &r, ctx).add_context(*ukid, *otherid)
2557 })
2558 .collect::<std::result::Result<_, _>>()?
2559 }
2560
2561 match (base, meta) {
2562 (KnownType::Error, _) => {
2563 unreachable!()
2564 }
2565 (_, MetaType::Any)
2567 | (KnownType::Named(_), MetaType::Type)
2569 | (KnownType::Tuple, MetaType::Type)
2570 | (KnownType::Array, MetaType::Type)
2571 | (KnownType::Wire, MetaType::Type)
2572 | (KnownType::Bool(_), MetaType::Bool)
2573 | (KnownType::Inverted, MetaType::Type)
2574 | (KnownType::Integer(_), MetaType::Int)
2576 | (KnownType::Integer(_), MetaType::Number)
2577 => {
2578 let new = self.add_type_var(TypeVar::Known(*loc, base.clone(), new_params));
2579
2580 Ok((new, vec![otherid.clone(), ukid.clone()]))
2581 },
2582 (KnownType::Integer(val), MetaType::Uint)
2584 => {
2585 if val < &0.to_bigint() {
2586 Err(meta_err_producer!())
2587 } else {
2588 let new = self.add_type_var(TypeVar::Known(*loc, base.clone(), new_params));
2589
2590 Ok((new, vec![otherid.clone(), ukid.clone()]))
2591 }
2592 },
2593
2594 (KnownType::Integer(_), MetaType::Type) => Err(meta_err_producer!()),
2596
2597 (_, MetaType::Bool) => Err(meta_err_producer!()),
2599 (KnownType::Bool(_), _) => Err(meta_err_producer!()),
2600
2601 (KnownType::Named(_), MetaType::Int | MetaType::Number | MetaType::Uint)
2603 | (KnownType::Tuple, MetaType::Int | MetaType::Uint | MetaType::Number)
2604 | (KnownType::Array, MetaType::Int | MetaType::Uint | MetaType::Number)
2605 | (KnownType::Wire, MetaType::Int | MetaType::Uint | MetaType::Number)
2606 | (KnownType::Inverted, MetaType::Int | MetaType::Uint | MetaType::Number)
2607 => Err(meta_err_producer!())
2608 }
2609 }
2610 };
2611
2612 let (new_type, replaced_types) = result?;
2613
2614 self.trace_stack.push(TraceStackEntry::Unified(
2615 v1.debug_resolve(self),
2616 v2.debug_resolve(self),
2617 new_type.debug_resolve(self),
2618 replaced_types
2619 .iter()
2620 .map(|v| v.debug_resolve(self))
2621 .collect(),
2622 ));
2623
2624 for replaced_type in &replaced_types {
2625 if v1.inner != v2.inner {
2626 let (from, to) = (replaced_type.get_type(self), new_type.get_type(self));
2627 self.replacements.insert(from, to);
2628 if let Err(rec) = self.check_type_for_recursion(to, &mut vec![]) {
2629 let err_t = self.t_err(().nowhere());
2630 self.replacements.insert(to, err_t);
2631 return Err(UnificationError::RecursiveType(rec));
2632 }
2633 }
2634 }
2635
2636 Ok(new_type)
2637 }
2638
2639 pub fn can_unify(&mut self, e1: &impl HasType, e2: &impl HasType, ctx: &Context) -> bool {
2640 self.trace_stack
2641 .push(TraceStackEntry::Enter("Running can_unify".to_string()));
2642 let result = self.do_and_restore(|s| s.unify(e1, e2, ctx)).is_ok();
2643 self.trace_stack.push(TraceStackEntry::Exit);
2644 result
2645 }
2646
2647 #[tracing::instrument(level = "trace", skip_all)]
2648 pub fn unify(
2649 &mut self,
2650 e1: &impl HasType,
2651 e2: &impl HasType,
2652 ctx: &Context,
2653 ) -> std::result::Result<TypeVarID, UnificationError> {
2654 let new_type = self.unify_inner(e1, e2, ctx)?;
2655
2656 loop {
2659 trace!("Updating constraints");
2660 let new_info;
2662 (self.constraints, new_info) = self
2663 .constraints
2664 .clone()
2665 .update_type_level_value_constraints(self);
2666
2667 if new_info.is_empty() {
2668 break;
2669 }
2670
2671 for constraint in new_info {
2672 trace!(
2673 "Constraint replaces {} with {:?}",
2674 constraint.inner.0.display(self),
2675 constraint.inner.1
2676 );
2677
2678 let ((var, replacement), loc) = constraint.split_loc();
2679
2680 self.trace_stack
2681 .push(TraceStackEntry::InferringFromConstraints(
2682 var.debug_resolve(self),
2683 replacement.val.clone(),
2684 ));
2685
2686 let expected_type = self.add_type_var(TypeVar::Known(loc, replacement.val, vec![]));
2688 let result = self.unify_inner(&expected_type.clone().at_loc(&loc), &var, ctx);
2689 let is_meta_error = matches!(result, Err(UnificationError::MetaMismatch { .. }));
2690 match result {
2691 Ok(_) => {}
2692 Err(UnificationError::Normal(Tm { mut e, mut g }))
2693 | Err(UnificationError::MetaMismatch(Tm { mut e, mut g })) => {
2694 e.inside.replace(
2695 replacement
2696 .context
2697 .inside
2698 .replace_inside(var, e.failing, self),
2699 );
2700 g.inside.replace(
2701 replacement
2702 .context
2703 .inside
2704 .replace_inside(var, g.failing, self),
2705 );
2706 return Err(UnificationError::FromConstraints {
2707 got: g,
2708 expected: e,
2709 source: replacement.context.source,
2710 loc,
2711 is_meta_error,
2712 });
2713 }
2714 Err(
2715 e @ UnificationError::FromConstraints { .. }
2716 | e @ UnificationError::Specific { .. }
2717 | e @ UnificationError::RecursiveType(_)
2718 | e @ UnificationError::UnsatisfiedTraits { .. },
2719 ) => return Err(e),
2720 };
2721 }
2722 }
2723
2724 Ok(new_type)
2725 }
2726
2727 fn check_type_for_recursion(
2728 &self,
2729 ty: TypeVarID,
2730 seen: &mut Vec<TypeVarID>,
2731 ) -> std::result::Result<(), String> {
2732 seen.push(ty);
2733 match ty.resolve(self) {
2734 TypeVar::Known(_, base, params) => {
2735 for (i, param) in params.iter().enumerate() {
2736 if seen.contains(param) {
2737 return Err("*".to_string());
2738 }
2739
2740 if let Err(rest) = self.check_type_for_recursion(*param, seen) {
2741 let list = params
2742 .iter()
2743 .enumerate()
2744 .map(|(j, _)| {
2745 if j == i {
2746 rest.clone()
2747 } else {
2748 "_".to_string()
2749 }
2750 })
2751 .join(", ");
2752
2753 match base {
2754 KnownType::Error => {}
2755 KnownType::Named(name_id) => {
2756 return Err(format!("{name_id}<{}>", list));
2757 }
2758 KnownType::Bool(_) | KnownType::Integer(_) => {
2759 unreachable!("Encountered recursive type level bool or int")
2760 }
2761 KnownType::Tuple => return Err(format!("({})", list)),
2762 KnownType::Array => return Err(format!("[{}]", list)),
2763 KnownType::Wire => return Err(format!("&{}", list)),
2764 KnownType::Inverted => return Err(format!("inv {}", list)),
2765 }
2766 }
2767 }
2768 }
2769 TypeVar::Unknown(_, _, traits, _) => {
2770 for t in &traits.inner {
2771 for param in &t.type_params {
2772 if seen.contains(param) {
2773 return Err("...".to_string());
2774 }
2775
2776 self.check_type_for_recursion(*param, seen)?;
2777 }
2778 }
2779 }
2780 }
2781 seen.pop();
2782 Ok(())
2783 }
2784
    /// Check that the type `var` implements every trait required by `traits`.
    ///
    /// For each required trait, the impls registered in `self.trait_impls` for `var`'s impl
    /// target are filtered. An impl is kept only if its trait name matches and both its trait
    /// type parameters and target type parameters unify with the requirement. Those trial
    /// unifications run between `checkpoint()`/`restore()` and are rolled back.
    ///
    /// Returns one `(impl, requirement)` pair per required trait when each has exactly one
    /// matching impl.
    ///
    /// `trait_is_expected` selects which side of a mismatch the trait list is reported on. When
    /// true, the traits are the "expected" side and `var` is what was "got".
    ///
    /// # Errors
    /// - `UnificationError::Specific` if more than one impl matches a requirement.
    /// - If any trait is unsatisfied: `UnificationError::UnsatisfiedTraits` when
    ///   `trait_is_expected`, unless the only missing trait is `Number`, in which case (and
    ///   whenever `trait_is_expected` is false) a `UnificationError::Normal` mismatch against a
    ///   fresh generic carrying the missing traits.
    ///
    /// # Panics
    /// Panics if `var` resolves to an unknown type, or if `Number` is not in the symbol table.
    fn ensure_impls(
        &mut self,
        var: &TypeVarID,
        traits: &TraitList,
        trait_is_expected: bool,
        trait_list_loc: &Loc<()>,
        ctx: &Context,
    ) -> std::result::Result<Vec<(TraitImpl, TraitReq)>, UnificationError> {
        self.trace_stack.push(TraceStackEntry::EnsuringImpls(
            var.debug_resolve(self),
            traits.clone(),
            trait_is_expected,
        ));

        // Looked up so that a lone missing `Number` bound can be reported as a plain mismatch.
        let number = ctx
            .symtab
            .lookup_trait(&Path::from_strs(&["Number"]).nowhere())
            .expect("Did not find number in symtab")
            .0;

        // Build the error for a set of unsatisfied traits. Which variant is produced, and which
        // side (`e`/`g`) holds the traits, depends on `trait_is_expected`.
        macro_rules! error_producer {
            ($required_traits:expr) => {
                if trait_is_expected {
                    if $required_traits.inner.len() == 1
                        && $required_traits
                            .get_trait(&TraitName::Named(number.clone().nowhere()))
                            .is_some()
                    {
                        Err(UnificationError::Normal(Tm {
                            e: UnificationTrace::new(
                                self.new_generic_with_traits(*trait_list_loc, $required_traits),
                            ),
                            g: UnificationTrace::new(var.clone()),
                        }))
                    } else {
                        Err(UnificationError::UnsatisfiedTraits {
                            var: *var,
                            traits: $required_traits.inner,
                            target_loc: trait_list_loc.clone(),
                        })
                    }
                } else {
                    Err(UnificationError::Normal(Tm {
                        e: UnificationTrace::new(var.clone()),
                        g: UnificationTrace::new(
                            self.new_generic_with_traits(*trait_list_loc, $required_traits),
                        ),
                    }))
                }
            };
        }

        // Clone the resolved type so that `self` is free to be mutated inside the arms.
        match &var.resolve(self).clone() {
            TypeVar::Known(_, known, params) if known.into_impl_target().is_some() => {
                let Some(target) = known.into_impl_target() else {
                    unreachable!()
                };

                // Left: (matching impl, requirement). Right: requirement with no matching impl.
                let (impls, unsatisfied): (Vec<_>, Vec<_>) = traits
                    .inner
                    .iter()
                    .map(|trait_req| {
                        if let Some(impld) = self.trait_impls.inner.get(&target).cloned() {
                            let target_impls = impld
                                .iter()
                                .filter_map(|trait_impl| {
                                    // Trial unifications, rolled back by `restore()` below.
                                    // NOTE(review): these run even for impls of other traits;
                                    // the name check is only applied afterwards.
                                    self.checkpoint();
                                    let trait_params_match = trait_impl
                                        .trait_type_params
                                        .iter()
                                        .zip(trait_req.type_params.iter())
                                        .all(|(l, r)| {
                                            let l = l.make_copy(self);
                                            self.unify(&l, r, ctx).is_ok()
                                        });

                                    let impl_params_match =
                                        trait_impl.target_type_params.iter().zip(params).all(
                                            |(l, r)| {
                                                let l = l.make_copy(self);
                                                self.unify(&l, r, ctx).is_ok()
                                            },
                                        );
                                    self.restore();

                                    if trait_impl.name == trait_req.name
                                        && trait_params_match
                                        && impl_params_match
                                    {
                                        Some(trait_impl)
                                    } else {
                                        None
                                    }
                                })
                                .collect::<Vec<_>>();

                            if target_impls.len() == 0 {
                                Ok(Either::Right(trait_req.clone()))
                            } else if target_impls.len() == 1 {
                                let target_impl = *target_impls.last().unwrap();
                                Ok(Either::Left((target_impl.clone(), trait_req.inner.clone())))
                            } else {
                                // Ambiguous: refuse to pick one of several matching impls.
                                Err(UnificationError::Specific(diag_anyhow!(
                                    trait_req,
                                    "Found more than one impl of {} for {}",
                                    trait_req.display(self),
                                    var.display(self)
                                )))
                            }
                        } else {
                            // No impls at all are registered for this target.
                            Ok(Either::Right(trait_req.clone()))
                        }
                    })
                    .collect::<std::result::Result<Vec<_>, _>>()?
                    .into_iter()
                    .partition_map(|x| x);

                if unsatisfied.is_empty() {
                    self.trace_stack.push(TraceStackEntry::Message(
                        "Ensuring impl successful".to_string(),
                    ));
                    Ok(impls)
                } else {
                    error_producer!(TraitList::from_vec(unsatisfied.clone()))
                }
            }
            TypeVar::Unknown(_, _, _, _) => {
                panic!("running ensure_impls on an unknown type")
            }
            // Known types without an impl target can only satisfy an empty trait list.
            _ => {
                if traits.inner.is_empty() {
                    Ok(vec![])
                } else {
                    error_producer!(traits.clone())
                }
            }
        }
    }
2925
2926 pub fn unify_expression_generic_error(
2927 &mut self,
2928 expr: &Loc<Expression>,
2929 other: &impl HasType,
2930 ctx: &Context,
2931 ) -> Result<TypeVarID> {
2932 self.unify(&expr.inner, other, ctx)
2933 .into_default_diagnostic(expr.loc(), self)
2934 }
2935
2936 pub fn check_requirements(&mut self, is_final_check: bool, ctx: &Context) -> Result<()> {
2937 loop {
2939 let (retain, replacements_option): (Vec<_>, Vec<_>) = self
2943 .requirements
2944 .clone()
2945 .iter()
2946 .map(|req| match req.check(self, ctx)? {
2947 requirements::RequirementResult::NoChange => Ok((true, None)),
2948 requirements::RequirementResult::UnsatisfiedNow(diag) => {
2949 if is_final_check {
2950 Err(diag)
2951 } else {
2952 Ok((true, None))
2953 }
2954 }
2955 requirements::RequirementResult::Satisfied(replacement) => {
2956 self.trace_stack
2957 .push(TraceStackEntry::ResolvedRequirement(req.clone()));
2958 Ok((false, Some(replacement)))
2959 }
2960 })
2961 .collect::<Result<Vec<_>>>()?
2962 .into_iter()
2963 .unzip();
2964
2965 let replacements = replacements_option
2966 .into_iter()
2967 .flatten()
2968 .flatten()
2969 .collect::<Vec<_>>();
2970
2971 self.requirements = self
2973 .requirements
2974 .drain(0..)
2975 .zip(retain)
2976 .filter_map(|(req, keep)| if keep { Some(req) } else { None })
2977 .collect();
2978
2979 if replacements.is_empty() {
2980 break;
2981 }
2982
2983 for Replacement { from, to, context } in replacements {
2984 self.unify(&to, &from, ctx).into_diagnostic_or_default(
2985 from.loc(),
2986 context,
2987 self,
2988 )?;
2989 }
2990 }
2991
2992 Ok(())
2993 }
2994
2995 pub fn get_replacement(&self, var: &TypeVarID) -> TypeVarID {
2996 self.replacements.get(*var)
2997 }
2998
2999 pub fn do_and_restore<T, E>(
3000 &mut self,
3001 inner: impl Fn(&mut Self) -> std::result::Result<T, E>,
3002 ) -> std::result::Result<T, E> {
3003 self.checkpoint();
3004 let result = inner(self);
3005 self.restore();
3006 result
3007 }
3008
3009 fn checkpoint(&mut self) {
3012 self.trace_stack
3013 .push(TraceStackEntry::Enter("Creating checkpoint".to_string()));
3014 self.replacements.push();
3015 }
3016
3017 fn restore(&mut self) {
3018 self.replacements.discard_top();
3019 self.trace_stack.push(TraceStackEntry::Exit);
3020 }
3021}
3022
3023impl TypeState {
3024 pub fn print_equations(&self) {
3025 for (lhs, rhs) in &self.equations {
3026 println!(
3027 "{} -> {}",
3028 format!("{lhs}").blue(),
3029 format!("{}", rhs.debug_resolve(self)).green()
3030 )
3031 }
3032
3033 println!("\nReplacments:");
3034
3035 for repl_stack in &self.replacements.all() {
3036 let replacements = { repl_stack.borrow().clone() };
3037 for (lhs, rhs) in replacements.iter().sorted() {
3038 println!(
3039 "{} -> {} ({} -> {})",
3040 format!("{}", lhs.inner).blue(),
3041 format!("{}", rhs.inner).green(),
3042 format!("{}", lhs.debug_resolve(self)).blue(),
3043 format!("{}", rhs.debug_resolve(self)).green(),
3044 )
3045 }
3046 println!("---")
3047 }
3048
3049 println!("\n Requirements:");
3050
3051 for requirement in &self.requirements {
3052 println!("{:?}", requirement)
3053 }
3054
3055 println!()
3056 }
3057}
3058
3059#[must_use]
3060pub struct UnificationBuilder {
3061 lhs: TypeVarID,
3062 rhs: TypeVarID,
3063}
3064impl UnificationBuilder {
3065 pub fn commit(
3066 self,
3067 state: &mut TypeState,
3068 ctx: &Context,
3069 ) -> std::result::Result<TypeVarID, UnificationError> {
3070 state.unify(&self.lhs, &self.rhs, ctx)
3071 }
3072}
3073
3074pub trait HasType: std::fmt::Debug {
3075 fn get_type(&self, state: &TypeState) -> TypeVarID {
3076 self.try_get_type(state)
3077 .expect(&format!("Did not find a type for {self:?}"))
3078 }
3079
3080 fn try_get_type(&self, state: &TypeState) -> Option<TypeVarID> {
3081 let id = self.get_type_impl(state);
3082 id.map(|id| state.get_replacement(&id))
3083 }
3084
3085 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID>;
3086
3087 fn unify_with(&self, rhs: &dyn HasType, state: &TypeState) -> UnificationBuilder {
3088 UnificationBuilder {
3089 lhs: self.get_type(state),
3090 rhs: rhs.get_type(state),
3091 }
3092 }
3093}
3094
3095impl HasType for TypeVarID {
3096 fn get_type_impl(&self, _state: &TypeState) -> Option<TypeVarID> {
3097 Some(*self)
3098 }
3099}
3100impl HasType for Loc<TypeVarID> {
3101 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3102 self.inner.try_get_type(state)
3103 }
3104}
3105impl HasType for TypedExpression {
3106 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3107 state.maybe_type_of(self).cloned()
3108 }
3109}
3110impl HasType for Expression {
3111 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3112 state.maybe_type_of(&TypedExpression::Id(self.id)).cloned()
3113 }
3114}
3115impl HasType for Loc<Expression> {
3116 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3117 state
3118 .maybe_type_of(&TypedExpression::Id(self.inner.id))
3119 .cloned()
3120 }
3121}
3122impl HasType for Pattern {
3123 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3124 state.maybe_type_of(&TypedExpression::Id(self.id)).cloned()
3125 }
3126}
3127impl HasType for Loc<Pattern> {
3128 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3129 state
3130 .maybe_type_of(&TypedExpression::Id(self.inner.id))
3131 .cloned()
3132 }
3133}
3134impl HasType for NameID {
3135 fn get_type_impl(&self, state: &TypeState) -> Option<TypeVarID> {
3136 state
3137 .maybe_type_of(&TypedExpression::Name(self.clone()))
3138 .cloned()
3139 }
3140}