spade_typeinference/
equation.rs

1use itertools::Itertools;
2use std::collections::{BTreeMap, HashMap};
3
4use num::BigInt;
5use serde::{Deserialize, Serialize};
6use spade_common::{
7    id_tracker::ExprID,
8    location_info::{Loc, WithLocation},
9    name::NameID,
10};
11use spade_types::{meta_types::MetaType, KnownType};
12
13use crate::{
14    traits::{TraitList, TraitReq},
15    HasType, TypeState,
16};
17
/// Copyable handle to a [`TypeVar`] stored inside a `TypeState`.
///
/// The handle does not own the variable; use [`TypeVarID::resolve`] to obtain
/// the `TypeVar` it currently refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct TypeVarID {
    /// Index of the variable; `resolve` uses the `inner` of the final ID to
    /// index into `TypeState::type_vars`.
    pub inner: usize,
    /// Key from the TypeState from which this var originates. See the `key` field
    /// of the type state for details
    pub type_state_key: u64,
}
// Empty marker impl of `WithLocation`.
// NOTE(review): presumably this is what allows wrapping the ID in a `Loc` — confirm
// against `spade_common::location_info`.
impl WithLocation for TypeVarID {}
26
27impl TypeVarID {
28    pub fn resolve(self, state: &TypeState) -> &TypeVar {
29        assert!(
30            state.keys.contains(&self.type_state_key),
31            "Type var key mismatch. Type states are being mixed incorrectly. Type state has {:?}, var has {}", state.keys, self.type_state_key
32        );
33        // In case our ID is stale, we'll need to look up the final ID
34        let final_id = self.get_type(state);
35        state.type_vars.get(final_id.inner).unwrap()
36    }
37
38    pub fn replace_inside(
39        self,
40        from: TypeVarID,
41        to: TypeVarID,
42        state: &mut TypeState,
43    ) -> TypeVarID {
44        if self.get_type(state) == from.get_type(state) {
45            to
46        } else {
47            let mut new = self.resolve(state).clone();
48            match &mut new {
49                TypeVar::Known(_, _known_type, params) => {
50                    params
51                        .iter_mut()
52                        .for_each(|var| *var = var.replace_inside(from, to, state));
53                }
54                TypeVar::Unknown(_, _, trait_list, _) => {
55                    trait_list.inner.iter_mut().for_each(|var| {
56                        let TraitReq {
57                            name: _,
58                            type_params,
59                        } = &mut var.inner;
60
61                        type_params
62                            .iter_mut()
63                            .for_each(|t| *t = t.replace_inside(from, to, state));
64                    })
65                }
66            };
67
68            // NOTE: For performance we could consider not doing replacement if none
69            // of the inner types changed. For now, we only use this in diagnostics,
70            // so even for performance, it won't make a difference
71            new.insert(state)
72        }
73    }
74
75    pub fn display(self, type_state: &TypeState) -> String {
76        self.display_with_meta(false, type_state)
77    }
78
79    pub fn display_with_meta(self, meta: bool, type_state: &TypeState) -> String {
80        match self.resolve(type_state) {
81            TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
82            TypeVar::Known(_, KnownType::Named(t), params) => {
83                let generics = if params.is_empty() {
84                    String::new()
85                } else {
86                    format!(
87                        "<{}>",
88                        params
89                            .iter()
90                            .map(|p| format!("{}", p.display_with_meta(meta, type_state)))
91                            .collect::<Vec<_>>()
92                            .join(", ")
93                    )
94                };
95                format!("{}{}", t, generics)
96            }
97            TypeVar::Known(_, KnownType::Integer(inner), _) => {
98                format!("{inner}")
99            }
100            TypeVar::Known(_, KnownType::Bool(inner), _) => {
101                format!("{inner}")
102            }
103            TypeVar::Known(_, KnownType::Tuple, params) => {
104                format!(
105                    "({})",
106                    params
107                        .iter()
108                        .map(|t| format!("{}", t.display_with_meta(meta, type_state)))
109                        .collect::<Vec<_>>()
110                        .join(", ")
111                )
112            }
113            TypeVar::Known(_, KnownType::Array, params) => {
114                format!(
115                    "[{}; {}]",
116                    params[0].display_with_meta(meta, type_state),
117                    params[1].display_with_meta(meta, type_state)
118                )
119            }
120            TypeVar::Known(_, KnownType::Wire, params) => {
121                format!("&{}", params[0].display_with_meta(meta, type_state))
122            }
123            TypeVar::Known(_, KnownType::Inverted, params) => {
124                format!("inv {}", params[0].display_with_meta(meta, type_state))
125            }
126            TypeVar::Unknown(_, _, traits, meta_type) => match meta_type {
127                MetaType::Type => {
128                    if !traits.inner.is_empty() {
129                        traits
130                            .inner
131                            .iter()
132                            .map(|t| t.display_with_meta(meta, type_state))
133                            .join(" + ")
134                    } else {
135                        "_".to_string()
136                    }
137                }
138                _ => {
139                    if meta {
140                        format!("{}", meta_type)
141                    } else {
142                        format!("_")
143                    }
144                }
145            },
146        }
147    }
148
149    pub fn debug_resolve(self, state: &TypeState) -> TypeVarString {
150        match self.resolve(state) {
151            TypeVar::Known(_, base, params) => {
152                let params = if params.is_empty() {
153                    "".to_string()
154                } else {
155                    format!(
156                        "<{}>",
157                        params.iter().map(|t| t.debug_resolve(state).0).join(", ")
158                    )
159                };
160                let base = match base {
161                    KnownType::Named(name_id) => format!("{name_id}"),
162                    KnownType::Integer(big_int) => format!("{big_int}"),
163                    KnownType::Bool(val) => format!("{val}"),
164                    KnownType::Tuple => format!("Tuple"),
165                    KnownType::Array => format!("Array"),
166                    KnownType::Wire => format!("&"),
167                    KnownType::Inverted => format!("inv &"),
168                    KnownType::Error => format!("{{error}}"),
169                };
170                TypeVarString(format!("{base}{params}"), self)
171            }
172            TypeVar::Unknown(_, id, traits, _meta_type) => {
173                let traits = if traits.inner.is_empty() {
174                    "".to_string()
175                } else {
176                    format!(
177                        "({})",
178                        traits
179                            .inner
180                            .iter()
181                            .map(|t| t.debug_display(state))
182                            .join(" + ")
183                    )
184                };
185                TypeVarString(format!("t{id}{traits}"), self)
186            }
187        }
188    }
189}
190
/// A type which should not be resolved directly but can be used to create new
/// copies with unique type var ids
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)]
pub struct TemplateTypeVarID {
    // The wrapped ID. Kept private so that it is only reachable through the
    // methods of this type.
    inner: TypeVarID,
}
197
198impl TemplateTypeVarID {
199    pub fn new(inner: TypeVarID) -> Self {
200        Self { inner }
201    }
202
203    pub fn make_copy(&self, state: &mut TypeState) -> TypeVarID {
204        self.make_copy_with_mapping(state, &mut BTreeMap::new())
205    }
206
207    pub fn make_copy_with_mapping(
208        &self,
209        state: &mut TypeState,
210        mapped: &mut BTreeMap<TemplateTypeVarID, TypeVarID>,
211    ) -> TypeVarID {
212        if let Some(prev) = mapped.get(self) {
213            return *prev;
214        }
215
216        let new = match self.inner.resolve(state).clone() {
217            TypeVar::Known(loc, base, params) => TypeVar::Known(
218                loc,
219                base,
220                params
221                    .into_iter()
222                    .map(|p| TemplateTypeVarID { inner: p }.make_copy_with_mapping(state, mapped))
223                    .collect(),
224            ),
225            TypeVar::Unknown(loc, id, TraitList { inner: tl }, meta_type) => TypeVar::Unknown(
226                loc,
227                id,
228                TraitList {
229                    inner: tl
230                        .into_iter()
231                        .map(|loc| {
232                            loc.map(|req| TraitReq {
233                                name: req.name,
234                                type_params: req
235                                    .type_params
236                                    .into_iter()
237                                    .map(|p| {
238                                        TemplateTypeVarID { inner: p }
239                                            .make_copy_with_mapping(state, mapped)
240                                    })
241                                    .collect(),
242                            })
243                        })
244                        .collect(),
245                },
246                meta_type,
247            ),
248        };
249        let result = state.add_type_var(new);
250        mapped.insert(*self, result);
251        result
252    }
253}
254
/// Map from a typed expression (an expression ID or a name) to the ID of the
/// type variable associated with it.
pub type TypeEquations = HashMap<TypedExpression, TypeVarID>;
256
/// A frozen TypeVar that can be printed for debugging purposes
///
/// Field `0` is the rendered text, which is what `Display` prints. Field `1` is
/// the ID the text was produced from (see `TypeVarID::debug_resolve`).
#[derive(Debug, Clone)]
pub struct TypeVarString(pub String, pub TypeVarID);
260
261impl std::fmt::Display for TypeVarString {
262    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263        write!(f, "{}", self.0)
264    }
265}
266
/// A type variable represents the type of something in the program. It is mapped
/// to expressions by type equations in the TypeState.
///
/// When TypeVars are passed externally into TypeState, they must be checked for replacement,
/// as the type inference process might have refined them.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Debug)]
pub enum TypeVar {
    /// The base type is known and has a list of parameters
    ///
    /// Fields: a location, the base type, and the IDs of the type parameters.
    Known(Loc<()>, KnownType, Vec<TypeVarID>),
    /// The type is unknown, but must satisfy the specified traits. When the generic substitution
    /// is done, the TypeVars will be carried over to the KnownType type vars
    ///
    /// Fields: a location, a numeric id (printed as `t{id}` by
    /// `TypeVarID::debug_resolve`), the required traits, and the meta type.
    Unknown(Loc<()>, u64, TraitList, MetaType),
}

// Empty marker impl of `WithLocation`.
impl WithLocation for TypeVar {}
282
283impl TypeVar {
284    pub fn into_known(&self, type_state: &TypeState) -> Option<KnownTypeVar> {
285        match self {
286            TypeVar::Known(loc, base, params) => Some(KnownTypeVar(
287                loc.clone(),
288                base.clone(),
289                params
290                    .iter()
291                    .map(|t| t.resolve(type_state).into_known(type_state))
292                    .collect::<Option<_>>()?,
293            )),
294            TypeVar::Unknown(_, _, _, _) => None,
295        }
296    }
297
298    pub fn insert(self, into: &mut TypeState) -> TypeVarID {
299        into.add_type_var(self)
300    }
301
302    pub fn array(loc: Loc<()>, inner: TypeVarID, size: TypeVarID) -> Self {
303        TypeVar::Known(loc, KnownType::Array, vec![inner, size])
304    }
305
306    pub fn tuple(loc: Loc<()>, inner: Vec<TypeVarID>) -> Self {
307        TypeVar::Known(loc, KnownType::Tuple, inner)
308    }
309
310    pub fn unit(loc: Loc<()>) -> Self {
311        TypeVar::Known(loc, KnownType::Tuple, Vec::new())
312    }
313
314    pub fn wire(loc: Loc<()>, inner: TypeVarID) -> Self {
315        TypeVar::Known(loc, KnownType::Wire, vec![inner])
316    }
317
318    pub fn backward(loc: Loc<()>, inner: TypeVarID, type_state: &mut TypeState) -> Self {
319        TypeVar::Known(
320            loc,
321            KnownType::Inverted,
322            vec![type_state.add_type_var(TypeVar::Known(loc, KnownType::Wire, vec![inner]))],
323        )
324    }
325
326    pub fn inverted(loc: Loc<()>, inner: TypeVarID) -> Self {
327        TypeVar::Known(loc, KnownType::Inverted, vec![inner])
328    }
329
330    pub fn expect_known<T, U, K, O>(&self, on_known: K, on_unknown: U) -> T
331    where
332        U: FnOnce() -> T,
333        K: FnOnce(&KnownType, &[TypeVarID]) -> T,
334    {
335        match self {
336            TypeVar::Unknown(_, _, _, _) => on_unknown(),
337            TypeVar::Known(_, k, v) => on_known(k, v),
338        }
339    }
340
341    pub fn expect_named<T, E, U, K, O>(
342        &self,
343        on_named: K,
344        on_unknown: U,
345        on_other: O,
346        on_error: E,
347    ) -> T
348    where
349        U: FnOnce() -> T,
350        K: FnOnce(&NameID, &[TypeVarID]) -> T,
351        E: FnOnce() -> T,
352        O: FnOnce(&TypeVar) -> T,
353    {
354        match self {
355            TypeVar::Unknown(_, _, _, _) => on_unknown(),
356            TypeVar::Known(_, KnownType::Named(name), params) => on_named(name, params),
357            TypeVar::Known(_, KnownType::Error, _) => on_error(),
358            other => on_other(other),
359        }
360    }
361
362    /// Expect a named type, or TypeVar::Inverted(named), or a recursive inversion.
363    /// inverted_now is stateful and used to track if we are currently in an
364    /// inverted context.
365    pub fn resolve_named_or_inverted(
366        &self,
367        inverted_now: bool,
368        type_state: &TypeState,
369    ) -> ResolvedNamedOrInverted {
370        match self {
371            TypeVar::Unknown(_, _, _, _) => ResolvedNamedOrInverted::Unknown,
372            TypeVar::Known(_, KnownType::Inverted, params) => {
373                if params.len() != 1 {
374                    panic!("Found wire with {} params", params.len())
375                }
376                params[0]
377                    .resolve(type_state)
378                    .resolve_named_or_inverted(!inverted_now, type_state)
379            }
380            TypeVar::Known(_, KnownType::Named(name), params) => {
381                ResolvedNamedOrInverted::Named(inverted_now, name.clone(), params.clone())
382            }
383            _ => ResolvedNamedOrInverted::Other,
384        }
385    }
386
387    pub fn expect_specific_named<T, U, K, O>(
388        &self,
389        name: NameID,
390        on_correct: K,
391        on_unknown: U,
392        on_other: O,
393    ) -> T
394    where
395        U: FnOnce() -> T,
396        K: FnOnce(&[TypeVarID]) -> T,
397        O: FnOnce(&TypeVar) -> T,
398    {
399        match self {
400            TypeVar::Unknown(_, _, _, _) => on_unknown(),
401            TypeVar::Known(_, k, v) if k == &KnownType::Named(name) => on_correct(v),
402            other => on_other(other),
403        }
404    }
405
406    /// Assumes that this type is KnownType::Integer(size) and calls on_integer then. Otherwise
407    /// calls on_unknown or on_other depending on the type. If the integer is given type params,
408    /// panics
409    pub fn expect_integer<T, E, U, K, O>(
410        &self,
411        on_integer: K,
412        on_unknown: U,
413        on_other: O,
414        on_error: E,
415    ) -> T
416    where
417        U: FnOnce() -> T,
418        E: FnOnce() -> T,
419        K: FnOnce(BigInt) -> T,
420        O: FnOnce(&TypeVar) -> T,
421    {
422        match self {
423            TypeVar::Known(_, KnownType::Integer(size), params) => {
424                assert!(params.is_empty());
425                on_integer(size.clone())
426            }
427            TypeVar::Known(_, KnownType::Error, _) => on_error(),
428            TypeVar::Unknown(_, _, _, _) => on_unknown(),
429            other => on_other(other),
430        }
431    }
432
433    pub fn display(&self, type_state: &TypeState) -> String {
434        self.display_with_meta(false, type_state)
435    }
436
437    pub fn display_with_meta(&self, display_meta: bool, type_state: &TypeState) -> String {
438        match self {
439            TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
440            TypeVar::Known(_, KnownType::Named(t), params) => {
441                let generics = if params.is_empty() {
442                    String::new()
443                } else {
444                    format!(
445                        "<{}>",
446                        params
447                            .iter()
448                            .map(|p| format!("{}", p.display_with_meta(display_meta, type_state)))
449                            .collect::<Vec<_>>()
450                            .join(", ")
451                    )
452                };
453                format!("{}{}", t, generics)
454            }
455            TypeVar::Known(_, KnownType::Integer(inner), _) => {
456                format!("{inner}")
457            }
458            TypeVar::Known(_, KnownType::Bool(inner), _) => {
459                format!("{inner}")
460            }
461            TypeVar::Known(_, KnownType::Tuple, params) => {
462                format!(
463                    "({})",
464                    params
465                        .iter()
466                        .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
467                        .collect::<Vec<_>>()
468                        .join(", ")
469                )
470            }
471            TypeVar::Known(_, KnownType::Array, params) => {
472                format!(
473                    "[{}; {}]",
474                    params[0].display_with_meta(display_meta, type_state),
475                    params[1].display_with_meta(display_meta, type_state)
476                )
477            }
478            TypeVar::Known(_, KnownType::Wire, params) => {
479                format!("&{}", params[0].display_with_meta(display_meta, type_state))
480            }
481            TypeVar::Known(_, KnownType::Inverted, params) => {
482                format!(
483                    "inv {}",
484                    params[0].display_with_meta(display_meta, type_state)
485                )
486            }
487            TypeVar::Unknown(_, _, traits, meta) if traits.inner.is_empty() => {
488                if display_meta {
489                    format!("{meta}")
490                } else {
491                    "_".to_string()
492                }
493            }
494            // If we have traits, we know this is a type
495            TypeVar::Unknown(_, _, traits, _meta) => {
496                format!(
497                    "{}",
498                    traits
499                        .inner
500                        .iter()
501                        .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
502                        .join("+"),
503                )
504            }
505        }
506    }
507}
508
/// A type in which the base type and, recursively, all parameters are known.
///
/// The fields mirror `TypeVar::Known` (location, base type, parameters), except
/// that the parameters are stored inline as `KnownTypeVar`s instead of as
/// `TypeVarID`s. `TypeVar::into_known` produces values of this type.
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct KnownTypeVar(pub Loc<()>, pub KnownType, pub Vec<KnownTypeVar>);
511
512impl KnownTypeVar {
513    pub fn insert(&self, type_state: &mut TypeState) -> TypeVarID {
514        let KnownTypeVar(loc, base, params) = self;
515        TypeVar::Known(
516            loc.clone(),
517            base.clone(),
518            params.into_iter().map(|p| p.insert(type_state)).collect(),
519        )
520        .insert(type_state)
521    }
522}
523
524impl std::fmt::Display for KnownTypeVar {
525    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526        let KnownTypeVar(_, base, params) = self;
527
528        match base {
529            KnownType::Error => {
530                write!(f, "{{unknown}}")
531            }
532            KnownType::Named(name_id) => {
533                write!(f, "{name_id}")?;
534                if !params.is_empty() {
535                    write!(f, "<{}>", params.iter().map(|t| format!("{t}")).join(", "))?;
536                }
537                Ok(())
538            }
539            KnownType::Integer(val) => write!(f, "{val}"),
540            KnownType::Bool(val) => write!(f, "{val}"),
541            KnownType::Tuple => {
542                write!(f, "({})", params.iter().map(|t| format!("{t}")).join(", "))
543            }
544            KnownType::Array => write!(f, "[{}; {}]", params[0], params[1]),
545            KnownType::Wire => write!(f, "&{}", params[0]),
546            KnownType::Inverted => write!(f, "inv {}", params[0]),
547        }
548    }
549}
550
/// Result of `TypeVar::resolve_named_or_inverted`.
pub enum ResolvedNamedOrInverted {
    /// An unknown type was encountered
    Unknown,
    /// A named type was found. The fields are: whether it was found in an
    /// inverted context, its name, and its type parameters.
    Named(bool, NameID, Vec<TypeVarID>),
    /// A known type that is neither named nor inverted was encountered
    Other,
}
556
/// Something that has a type variable associated with it; used as the key type
/// of `TypeEquations`.
#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
pub enum TypedExpression {
    /// An expression, identified by its ID
    Id(ExprID),
    /// A name, identified by its `NameID`
    Name(NameID),
}

// Empty marker impl of `WithLocation`.
impl WithLocation for TypedExpression {}
564
565impl std::fmt::Display for TypedExpression {
566    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
567        match self {
568            TypedExpression::Id(i) => {
569                write!(f, "%{}", i.0)
570            }
571            TypedExpression::Name(p) => {
572                write!(f, "{}#{}", p, p.0)
573            }
574        }
575    }
576}