spade_typeinference/
equation.rs

1use std::collections::BTreeMap;
2
3use itertools::Itertools;
// File-local alias: an `imbl` hash map that uses `rustc_hash`'s FxHash hasher
// (`FxBuildHasher`) and `imbl`'s default shared pointer type instead of the defaults.
type HashMap<K, V> =
    imbl::GenericHashMap<K, V, rustc_hash::FxBuildHasher, imbl::shared_ptr::DefaultSharedPtr>;
6
7use num::BigInt;
8use serde::{Deserialize, Serialize};
9use spade_common::{id_tracker::ExprID, location_info::Loc, name::NameID};
10use spade_types::{meta_types::MetaType, KnownType};
11
12use crate::{
13    traits::{TraitList, TraitReq},
14    HasType, TypeState,
15};
16
/// Handle to a type variable stored in a [`TypeState`].
///
/// An ID can become stale when the variable it points to is replaced during inference.
/// Use [`TypeVarID::resolve`], which follows such replacements, to get the current
/// [`TypeVar`] instead of indexing with `inner` directly.
// NOTE: field order determines the derived `Ord` and the serialized layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct TypeVarID {
    /// Index into the shared type var table (read via `read_type_vars` in `resolve`).
    pub inner: usize,
    /// Key from the TypeState from which this var originates. See the `key` field
    /// of the type state for details. `resolve` asserts that this key is one of
    /// `state.owned.keys`.
    pub type_state_key: u64,
}
24
25impl TypeVarID {
26    pub fn resolve(self, state: &TypeState) -> TypeVar {
27        assert!(
28            state.owned.keys.contains(&self.type_state_key),
29            "Type var key mismatch. Type states are being mixed incorrectly. Type state has {:?}, var has {}", state.owned.keys, self.type_state_key
30        );
31        // In case our ID is stale, we'll need to look up the final ID
32        let final_id = self.get_type(state);
33        state
34            .shared
35            .read_type_vars(|type_vars| type_vars.get(final_id.inner).unwrap().clone())
36    }
37
38    pub fn replace_inside(
39        self,
40        from: TypeVarID,
41        to: TypeVarID,
42        state: &mut TypeState,
43    ) -> TypeVarID {
44        if self.get_type(state) == from.get_type(state) {
45            to
46        } else {
47            let mut new = self.resolve(state).clone();
48            match &mut new {
49                TypeVar::Known(_, _known_type, params) => {
50                    params
51                        .iter_mut()
52                        .for_each(|var| *var = var.replace_inside(from, to, state));
53                }
54                TypeVar::Unknown(_, _, trait_list, _) => {
55                    trait_list.inner.iter_mut().for_each(|var| {
56                        let TraitReq {
57                            name: _,
58                            type_params,
59                        } = &mut var.inner;
60
61                        type_params
62                            .iter_mut()
63                            .for_each(|t| *t = t.replace_inside(from, to, state));
64                    })
65                }
66            };
67
68            // NOTE: For performance we could consider not doing replacement if none
69            // of the inner types changed. For now, we only use this in diagnostics,
70            // so even for performance, it won't make a difference
71            new.insert(state)
72        }
73    }
74
75    pub fn display(self, type_state: &TypeState) -> String {
76        self.display_with_meta(false, type_state)
77    }
78
79    pub fn display_with_meta(self, meta: bool, type_state: &TypeState) -> String {
80        match self.resolve(type_state) {
81            TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
82            TypeVar::Known(_, KnownType::Named(t), params) => {
83                let generics = if params.is_empty() {
84                    String::new()
85                } else {
86                    format!(
87                        "<{}>",
88                        params
89                            .iter()
90                            .map(|p| format!("{}", p.display_with_meta(meta, type_state)))
91                            .collect::<Vec<_>>()
92                            .join(", ")
93                    )
94                };
95                format!("{}{}", t, generics)
96            }
97            TypeVar::Known(_, KnownType::Integer(inner), _) => {
98                format!("{inner}")
99            }
100            TypeVar::Known(_, KnownType::Bool(inner), _) => {
101                format!("{inner}")
102            }
103            TypeVar::Known(_, KnownType::String(inner), _) => {
104                format!("{inner:?}")
105            }
106            TypeVar::Known(_, KnownType::Tuple, params) => {
107                format!(
108                    "({})",
109                    params
110                        .iter()
111                        .map(|t| format!("{}", t.display_with_meta(meta, type_state)))
112                        .collect::<Vec<_>>()
113                        .join(", ")
114                )
115            }
116            TypeVar::Known(_, KnownType::Array, params) => {
117                format!(
118                    "[{}; {}]",
119                    params[0].display_with_meta(meta, type_state),
120                    params[1].display_with_meta(meta, type_state)
121                )
122            }
123            TypeVar::Known(_, KnownType::Wire, params) => {
124                format!("&{}", params[0].display_with_meta(meta, type_state))
125            }
126            TypeVar::Known(_, KnownType::Inverted, params) => {
127                format!("inv {}", params[0].display_with_meta(meta, type_state))
128            }
129            TypeVar::Unknown(_, _, traits, meta_type) => match meta_type {
130                MetaType::Type => {
131                    if !traits.inner.is_empty() {
132                        traits
133                            .inner
134                            .iter()
135                            .map(|t| t.display_with_meta(meta, type_state))
136                            .join(" + ")
137                    } else {
138                        "_".to_string()
139                    }
140                }
141                _ => {
142                    if meta {
143                        format!("{}", meta_type)
144                    } else {
145                        format!("_")
146                    }
147                }
148            },
149        }
150    }
151
152    pub fn debug_resolve(self, state: &TypeState) -> TypeVarString {
153        match self.resolve(state) {
154            TypeVar::Known(_, base, params) => {
155                let params = if params.is_empty() {
156                    "".to_string()
157                } else {
158                    format!(
159                        "<{}>",
160                        params.iter().map(|t| t.debug_resolve(state).0).join(", ")
161                    )
162                };
163                let base = match base {
164                    KnownType::Named(name_id) => format!("{name_id}"),
165                    KnownType::Integer(big_int) => format!("{big_int}"),
166                    KnownType::Bool(val) => format!("{val}"),
167                    KnownType::String(val) => format!("{val:?}"),
168                    KnownType::Tuple => format!("Tuple"),
169                    KnownType::Array => format!("Array"),
170                    KnownType::Wire => format!("&"),
171                    KnownType::Inverted => format!("inv &"),
172                    KnownType::Error => format!("{{error}}"),
173                };
174                TypeVarString(format!("{base}{params}"), self)
175            }
176            TypeVar::Unknown(_, id, traits, _meta_type) => {
177                let traits = if traits.inner.is_empty() {
178                    "".to_string()
179                } else {
180                    format!(
181                        "({})",
182                        traits
183                            .inner
184                            .iter()
185                            .map(|t| t.debug_display(state))
186                            .join(" + ")
187                    )
188                };
189                TypeVarString(format!("t{id}{traits}"), self)
190            }
191        }
192    }
193}
194
/// A type which should not be resolved directly but can be used to create new
/// copies with unique type var ids, via `make_copy` or `make_copy_with_mapping`.
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)]
pub struct TemplateTypeVarID {
    inner: TypeVarID,
}
201
202impl TemplateTypeVarID {
203    pub fn new(inner: TypeVarID) -> Self {
204        Self { inner }
205    }
206
207    pub fn make_copy(&self, state: &mut TypeState) -> TypeVarID {
208        self.make_copy_with_mapping(state, &mut BTreeMap::new())
209    }
210
211    pub fn make_copy_with_mapping(
212        &self,
213        state: &mut TypeState,
214        mapped: &mut BTreeMap<TemplateTypeVarID, TypeVarID>,
215    ) -> TypeVarID {
216        if let Some(prev) = mapped.get(self) {
217            return *prev;
218        }
219
220        let new = match self.inner.resolve(state).clone() {
221            TypeVar::Known(loc, base, params) => TypeVar::Known(
222                loc,
223                base,
224                params
225                    .into_iter()
226                    .map(|p| TemplateTypeVarID { inner: p }.make_copy_with_mapping(state, mapped))
227                    .collect(),
228            ),
229            TypeVar::Unknown(loc, id, TraitList { inner: tl }, meta_type) => TypeVar::Unknown(
230                loc,
231                id,
232                TraitList {
233                    inner: tl
234                        .into_iter()
235                        .map(|loc| {
236                            loc.map(|req| TraitReq {
237                                name: req.name,
238                                type_params: req
239                                    .type_params
240                                    .into_iter()
241                                    .map(|p| {
242                                        TemplateTypeVarID { inner: p }
243                                            .make_copy_with_mapping(state, mapped)
244                                    })
245                                    .collect(),
246                            })
247                        })
248                        .collect(),
249                },
250                meta_type,
251            ),
252        };
253        let result = state.add_type_var(new);
254        mapped.insert(*self, result);
255        result
256    }
257}
258
/// Maps typed expressions and names to the type var that holds their type.
pub type TypeEquations = HashMap<TypedExpression, TypeVarID>;
260
261/// A frozen TypeVar that can be printed for debugging purposes
262#[derive(Debug, Clone)]
263pub struct TypeVarString(pub String, pub TypeVarID);
264
265impl std::fmt::Display for TypeVarString {
266    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        write!(f, "{}", self.0)
268    }
269}
270
/// A type variable represents the type of something in the program. It is mapped
/// to expressions by type equations in the TypeState.
///
/// When TypeVars are passed externally into TypeState, they must be checked for replacement,
/// as the type inference process might have refined them.
// NOTE: variant and field order determine the derived `Ord` and the serialized layout.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Debug)]
pub enum TypeVar {
    /// The base type is known and has a list of parameters, e.g. the element type and
    /// size of an array (see `TypeVar::array`).
    Known(Loc<()>, KnownType, Vec<TypeVarID>),
    /// The type is unknown, but must satisfy the specified traits. When the generic substitution
    /// is done, the TypeVars will be carried over to the KnownType type vars.
    ///
    /// The `u64` is an identifier shown as `t<id>` in debug output. The [`MetaType`] is
    /// the meta type of the unknown; an unknown with trait requirements is a
    /// `MetaType::Type`.
    Unknown(Loc<()>, u64, TraitList, MetaType),
}
284
285impl TypeVar {
286    pub fn into_known(&self, type_state: &TypeState) -> Option<KnownTypeVar> {
287        match self {
288            TypeVar::Known(loc, base, params) => Some(KnownTypeVar(
289                loc.clone(),
290                base.clone(),
291                params
292                    .iter()
293                    .map(|t| t.resolve(type_state).into_known(type_state))
294                    .collect::<Option<_>>()?,
295            )),
296            TypeVar::Unknown(_, _, _, _) => None,
297        }
298    }
299
300    pub fn insert(self, into: &mut TypeState) -> TypeVarID {
301        into.add_type_var(self)
302    }
303
304    pub fn array(loc: Loc<()>, inner: TypeVarID, size: TypeVarID) -> Self {
305        TypeVar::Known(loc, KnownType::Array, vec![inner, size])
306    }
307
308    pub fn tuple(loc: Loc<()>, inner: Vec<TypeVarID>) -> Self {
309        TypeVar::Known(loc, KnownType::Tuple, inner)
310    }
311
312    pub fn unit(loc: Loc<()>) -> Self {
313        TypeVar::Known(loc, KnownType::Tuple, Vec::new())
314    }
315
316    pub fn wire(loc: Loc<()>, inner: TypeVarID) -> Self {
317        TypeVar::Known(loc, KnownType::Wire, vec![inner])
318    }
319
320    pub fn backward(loc: Loc<()>, inner: TypeVarID, type_state: &mut TypeState) -> Self {
321        TypeVar::Known(
322            loc,
323            KnownType::Inverted,
324            vec![type_state.add_type_var(TypeVar::Known(loc, KnownType::Wire, vec![inner]))],
325        )
326    }
327
328    pub fn inverted(loc: Loc<()>, inner: TypeVarID) -> Self {
329        TypeVar::Known(loc, KnownType::Inverted, vec![inner])
330    }
331
332    pub fn expect_known<T, U, K, O>(&self, on_known: K, on_unknown: U) -> T
333    where
334        U: FnOnce() -> T,
335        K: FnOnce(&KnownType, &[TypeVarID]) -> T,
336    {
337        match self {
338            TypeVar::Unknown(_, _, _, _) => on_unknown(),
339            TypeVar::Known(_, k, v) => on_known(k, v),
340        }
341    }
342
343    pub fn expect_named<T, E, U, K, O>(
344        &self,
345        on_named: K,
346        on_unknown: U,
347        on_other: O,
348        on_error: E,
349    ) -> T
350    where
351        U: FnOnce() -> T,
352        K: FnOnce(&NameID, &[TypeVarID]) -> T,
353        E: FnOnce() -> T,
354        O: FnOnce(&TypeVar) -> T,
355    {
356        match self {
357            TypeVar::Unknown(_, _, _, _) => on_unknown(),
358            TypeVar::Known(_, KnownType::Named(name), params) => on_named(name, params),
359            TypeVar::Known(_, KnownType::Error, _) => on_error(),
360            other => on_other(other),
361        }
362    }
363
364    /// Expect a named type, or TypeVar::Inverted(named), or a recursive inversion.
365    /// inverted_now is stateful and used to track if we are currently in an
366    /// inverted context.
367    pub fn resolve_named_or_inverted(
368        &self,
369        inverted_now: bool,
370        type_state: &TypeState,
371    ) -> ResolvedNamedOrInverted {
372        match self {
373            TypeVar::Unknown(_, _, _, _) => ResolvedNamedOrInverted::Unknown,
374            TypeVar::Known(_, KnownType::Inverted, params) => {
375                if params.len() != 1 {
376                    panic!("Found wire with {} params", params.len())
377                }
378                params[0]
379                    .resolve(type_state)
380                    .resolve_named_or_inverted(!inverted_now, type_state)
381            }
382            TypeVar::Known(_, KnownType::Named(name), params) => {
383                ResolvedNamedOrInverted::Named(inverted_now, name.clone(), params.clone())
384            }
385            _ => ResolvedNamedOrInverted::Other,
386        }
387    }
388
389    pub fn expect_specific_named<T, U, K, O>(
390        &self,
391        name: NameID,
392        on_correct: K,
393        on_unknown: U,
394        on_other: O,
395    ) -> T
396    where
397        U: FnOnce() -> T,
398        K: FnOnce(&[TypeVarID]) -> T,
399        O: FnOnce(&TypeVar) -> T,
400    {
401        match self {
402            TypeVar::Unknown(_, _, _, _) => on_unknown(),
403            TypeVar::Known(_, k, v) if k == &KnownType::Named(name) => on_correct(v),
404            other => on_other(other),
405        }
406    }
407
408    /// Assumes that this type is KnownType::Integer(size) and calls on_integer then. Otherwise
409    /// calls on_unknown or on_other depending on the type. If the integer is given type params,
410    /// panics
411    pub fn expect_integer<T, E, U, K, O>(
412        &self,
413        on_integer: K,
414        on_unknown: U,
415        on_other: O,
416        on_error: E,
417    ) -> T
418    where
419        U: FnOnce() -> T,
420        E: FnOnce() -> T,
421        K: FnOnce(BigInt) -> T,
422        O: FnOnce(&TypeVar) -> T,
423    {
424        match self {
425            TypeVar::Known(_, KnownType::Integer(size), params) => {
426                assert!(params.is_empty());
427                on_integer(size.clone())
428            }
429            TypeVar::Known(_, KnownType::Error, _) => on_error(),
430            TypeVar::Unknown(_, _, _, _) => on_unknown(),
431            other => on_other(other),
432        }
433    }
434
435    /// Assumes that this type is KnownType::String(val) and calls on_string then. Otherwise
436    /// calls on_unknown or on_other depending on the type. If the integer is given type params,
437    /// panics
438    pub fn expect_string<T, E, U, K, O>(
439        &self,
440        on_string: K,
441        on_unknown: U,
442        on_other: O,
443        on_error: E,
444    ) -> T
445    where
446        U: FnOnce() -> T,
447        E: FnOnce() -> T,
448        K: FnOnce(String) -> T,
449        O: FnOnce(&TypeVar) -> T,
450    {
451        match self {
452            TypeVar::Known(_, KnownType::String(val), params) => {
453                assert!(params.is_empty());
454                on_string(val.clone())
455            }
456            TypeVar::Known(_, KnownType::Error, _) => on_error(),
457            TypeVar::Unknown(_, _, _, _) => on_unknown(),
458            other => on_other(other),
459        }
460    }
461
462    pub fn display(&self, type_state: &TypeState) -> String {
463        self.display_with_meta(false, type_state)
464    }
465
466    pub fn display_with_meta(&self, display_meta: bool, type_state: &TypeState) -> String {
467        match self {
468            TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
469            TypeVar::Known(_, KnownType::Named(t), params) => {
470                let generics = if params.is_empty() {
471                    String::new()
472                } else {
473                    format!(
474                        "<{}>",
475                        params
476                            .iter()
477                            .map(|p| format!("{}", p.display_with_meta(display_meta, type_state)))
478                            .collect::<Vec<_>>()
479                            .join(", ")
480                    )
481                };
482                format!("{}{}", t, generics)
483            }
484            TypeVar::Known(_, KnownType::Integer(inner), _) => {
485                format!("{inner}")
486            }
487            TypeVar::Known(_, KnownType::Bool(inner), _) => {
488                format!("{inner}")
489            }
490            TypeVar::Known(_, KnownType::String(inner), _) => {
491                format!("{inner:?}")
492            }
493            TypeVar::Known(_, KnownType::Tuple, params) => {
494                format!(
495                    "({})",
496                    params
497                        .iter()
498                        .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
499                        .collect::<Vec<_>>()
500                        .join(", ")
501                )
502            }
503            TypeVar::Known(_, KnownType::Array, params) => {
504                format!(
505                    "[{}; {}]",
506                    params[0].display_with_meta(display_meta, type_state),
507                    params[1].display_with_meta(display_meta, type_state)
508                )
509            }
510            TypeVar::Known(_, KnownType::Wire, params) => {
511                format!("&{}", params[0].display_with_meta(display_meta, type_state))
512            }
513            TypeVar::Known(_, KnownType::Inverted, params) => {
514                format!(
515                    "inv {}",
516                    params[0].display_with_meta(display_meta, type_state)
517                )
518            }
519            TypeVar::Unknown(_, _, traits, meta) if traits.inner.is_empty() => {
520                if display_meta {
521                    format!("{meta}")
522                } else {
523                    "_".to_string()
524                }
525            }
526            // If we have traits, we know this is a type
527            TypeVar::Unknown(_, _, traits, _meta) => {
528                format!(
529                    "{}",
530                    traits
531                        .inner
532                        .iter()
533                        .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
534                        .join("+"),
535                )
536            }
537        }
538    }
539}
540
541#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
542pub struct KnownTypeVar(pub Loc<()>, pub KnownType, pub Vec<KnownTypeVar>);
543
544impl KnownTypeVar {
545    pub fn insert(&self, type_state: &mut TypeState) -> TypeVarID {
546        let KnownTypeVar(loc, base, params) = self;
547        TypeVar::Known(
548            loc.clone(),
549            base.clone(),
550            params.into_iter().map(|p| p.insert(type_state)).collect(),
551        )
552        .insert(type_state)
553    }
554}
555
556impl std::fmt::Display for KnownTypeVar {
557    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
558        let KnownTypeVar(_, base, params) = self;
559
560        match base {
561            KnownType::Error => {
562                write!(f, "{{unknown}}")
563            }
564            KnownType::Named(name_id) => {
565                write!(f, "{name_id}")?;
566                if !params.is_empty() {
567                    write!(f, "<{}>", params.iter().map(|t| format!("{t}")).join(", "))?;
568                }
569                Ok(())
570            }
571            KnownType::Integer(val) => write!(f, "{val}"),
572            KnownType::Bool(val) => write!(f, "{val}"),
573            KnownType::String(val) => write!(f, "{val:?}"),
574            KnownType::Tuple => {
575                write!(f, "({})", params.iter().map(|t| format!("{t}")).join(", "))
576            }
577            KnownType::Array => write!(f, "[{}; {}]", params[0], params[1]),
578            KnownType::Wire => write!(f, "&{}", params[0]),
579            KnownType::Inverted => write!(f, "inv {}", params[0]),
580        }
581    }
582}
583
584pub enum ResolvedNamedOrInverted {
585    Unknown,
586    Named(bool, NameID, Vec<TypeVarID>),
587    Other,
588}
589
/// Something a type can be attached to in [`TypeEquations`]: either an expression,
/// identified by its [`ExprID`], or a name, identified by its [`NameID`].
#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
pub enum TypedExpression {
    Id(ExprID),
    Name(NameID),
}
595
596impl std::fmt::Display for TypedExpression {
597    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598        match self {
599            TypedExpression::Id(i) => {
600                write!(f, "%{}", i.0)
601            }
602            TypedExpression::Name(p) => {
603                write!(f, "{}#{}", p, p.0)
604            }
605        }
606    }
607}