// spade_typeinference/equation.rs

1use itertools::Itertools;
2use std::collections::{BTreeMap, HashMap};
3
4use num::BigInt;
5use serde::{Deserialize, Serialize};
6use spade_common::{id_tracker::ExprID, location_info::Loc, name::NameID};
7use spade_types::{meta_types::MetaType, KnownType};
8
9use crate::{
10    traits::{TraitList, TraitReq},
11    HasType, TypeState,
12};
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
15pub struct TypeVarID {
16    pub inner: usize,
17    /// Key from the TypeState from which this var originates. See the `key` field
18    /// of the type state for details
19    pub type_state_key: u64,
20}
21
22impl TypeVarID {
23    pub fn resolve(self, state: &TypeState) -> &TypeVar {
24        assert!(
25            state.keys.contains(&self.type_state_key),
26            "Type var key mismatch. Type states are being mixed incorrectly. Type state has {:?}, var has {}", state.keys, self.type_state_key
27        );
28        // In case our ID is stale, we'll need to look up the final ID
29        let final_id = self.get_type(state);
30        state.type_vars.get(final_id.inner).unwrap()
31    }
32
33    pub fn replace_inside(
34        self,
35        from: TypeVarID,
36        to: TypeVarID,
37        state: &mut TypeState,
38    ) -> TypeVarID {
39        if self.get_type(state) == from.get_type(state) {
40            to
41        } else {
42            let mut new = self.resolve(state).clone();
43            match &mut new {
44                TypeVar::Known(_, _known_type, params) => {
45                    params
46                        .iter_mut()
47                        .for_each(|var| *var = var.replace_inside(from, to, state));
48                }
49                TypeVar::Unknown(_, _, trait_list, _) => {
50                    trait_list.inner.iter_mut().for_each(|var| {
51                        let TraitReq {
52                            name: _,
53                            type_params,
54                        } = &mut var.inner;
55
56                        type_params
57                            .iter_mut()
58                            .for_each(|t| *t = t.replace_inside(from, to, state));
59                    })
60                }
61            };
62
63            // NOTE: For performance we could consider not doing replacement if none
64            // of the inner types changed. For now, we only use this in diagnostics,
65            // so even for performance, it won't make a difference
66            new.insert(state)
67        }
68    }
69
70    pub fn display(self, type_state: &TypeState) -> String {
71        self.display_with_meta(false, type_state)
72    }
73
74    pub fn display_with_meta(self, meta: bool, type_state: &TypeState) -> String {
75        match self.resolve(type_state) {
76            TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
77            TypeVar::Known(_, KnownType::Named(t), params) => {
78                let generics = if params.is_empty() {
79                    String::new()
80                } else {
81                    format!(
82                        "<{}>",
83                        params
84                            .iter()
85                            .map(|p| format!("{}", p.display_with_meta(meta, type_state)))
86                            .collect::<Vec<_>>()
87                            .join(", ")
88                    )
89                };
90                format!("{}{}", t, generics)
91            }
92            TypeVar::Known(_, KnownType::Integer(inner), _) => {
93                format!("{inner}")
94            }
95            TypeVar::Known(_, KnownType::Bool(inner), _) => {
96                format!("{inner}")
97            }
98            TypeVar::Known(_, KnownType::String(inner), _) => {
99                format!("{inner:?}")
100            }
101            TypeVar::Known(_, KnownType::Tuple, params) => {
102                format!(
103                    "({})",
104                    params
105                        .iter()
106                        .map(|t| format!("{}", t.display_with_meta(meta, type_state)))
107                        .collect::<Vec<_>>()
108                        .join(", ")
109                )
110            }
111            TypeVar::Known(_, KnownType::Array, params) => {
112                format!(
113                    "[{}; {}]",
114                    params[0].display_with_meta(meta, type_state),
115                    params[1].display_with_meta(meta, type_state)
116                )
117            }
118            TypeVar::Known(_, KnownType::Wire, params) => {
119                format!("&{}", params[0].display_with_meta(meta, type_state))
120            }
121            TypeVar::Known(_, KnownType::Inverted, params) => {
122                format!("inv {}", params[0].display_with_meta(meta, type_state))
123            }
124            TypeVar::Unknown(_, _, traits, meta_type) => match meta_type {
125                MetaType::Type => {
126                    if !traits.inner.is_empty() {
127                        traits
128                            .inner
129                            .iter()
130                            .map(|t| t.display_with_meta(meta, type_state))
131                            .join(" + ")
132                    } else {
133                        "_".to_string()
134                    }
135                }
136                _ => {
137                    if meta {
138                        format!("{}", meta_type)
139                    } else {
140                        format!("_")
141                    }
142                }
143            },
144        }
145    }
146
147    pub fn debug_resolve(self, state: &TypeState) -> TypeVarString {
148        match self.resolve(state) {
149            TypeVar::Known(_, base, params) => {
150                let params = if params.is_empty() {
151                    "".to_string()
152                } else {
153                    format!(
154                        "<{}>",
155                        params.iter().map(|t| t.debug_resolve(state).0).join(", ")
156                    )
157                };
158                let base = match base {
159                    KnownType::Named(name_id) => format!("{name_id}"),
160                    KnownType::Integer(big_int) => format!("{big_int}"),
161                    KnownType::Bool(val) => format!("{val}"),
162                    KnownType::String(val) => format!("{val:?}"),
163                    KnownType::Tuple => format!("Tuple"),
164                    KnownType::Array => format!("Array"),
165                    KnownType::Wire => format!("&"),
166                    KnownType::Inverted => format!("inv &"),
167                    KnownType::Error => format!("{{error}}"),
168                };
169                TypeVarString(format!("{base}{params}"), self)
170            }
171            TypeVar::Unknown(_, id, traits, _meta_type) => {
172                let traits = if traits.inner.is_empty() {
173                    "".to_string()
174                } else {
175                    format!(
176                        "({})",
177                        traits
178                            .inner
179                            .iter()
180                            .map(|t| t.debug_display(state))
181                            .join(" + ")
182                    )
183                };
184                TypeVarString(format!("t{id}{traits}"), self)
185            }
186        }
187    }
188}
189
190/// A type which which should not be resolved directly but can be used to create new
191/// copies with unique type var ids
192#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)]
193pub struct TemplateTypeVarID {
194    inner: TypeVarID,
195}
196
197impl TemplateTypeVarID {
198    pub fn new(inner: TypeVarID) -> Self {
199        Self { inner }
200    }
201
202    pub fn make_copy(&self, state: &mut TypeState) -> TypeVarID {
203        self.make_copy_with_mapping(state, &mut BTreeMap::new())
204    }
205
206    pub fn make_copy_with_mapping(
207        &self,
208        state: &mut TypeState,
209        mapped: &mut BTreeMap<TemplateTypeVarID, TypeVarID>,
210    ) -> TypeVarID {
211        if let Some(prev) = mapped.get(self) {
212            return *prev;
213        }
214
215        let new = match self.inner.resolve(state).clone() {
216            TypeVar::Known(loc, base, params) => TypeVar::Known(
217                loc,
218                base,
219                params
220                    .into_iter()
221                    .map(|p| TemplateTypeVarID { inner: p }.make_copy_with_mapping(state, mapped))
222                    .collect(),
223            ),
224            TypeVar::Unknown(loc, id, TraitList { inner: tl }, meta_type) => TypeVar::Unknown(
225                loc,
226                id,
227                TraitList {
228                    inner: tl
229                        .into_iter()
230                        .map(|loc| {
231                            loc.map(|req| TraitReq {
232                                name: req.name,
233                                type_params: req
234                                    .type_params
235                                    .into_iter()
236                                    .map(|p| {
237                                        TemplateTypeVarID { inner: p }
238                                            .make_copy_with_mapping(state, mapped)
239                                    })
240                                    .collect(),
241                            })
242                        })
243                        .collect(),
244                },
245                meta_type,
246            ),
247        };
248        let result = state.add_type_var(new);
249        mapped.insert(*self, result);
250        result
251    }
252}
253
254pub type TypeEquations = HashMap<TypedExpression, TypeVarID>;
255
256/// A frozen TypeVar that can be printed for debugging purposes
257#[derive(Debug, Clone)]
258pub struct TypeVarString(pub String, pub TypeVarID);
259
260impl std::fmt::Display for TypeVarString {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        write!(f, "{}", self.0)
263    }
264}
265
266/// A type variable represents the type of something in the program. It is mapped
267/// to expressions by type equations in the TypeState.
268///
269/// When TypeVars are passed externally into TypeState, they must be checked for replacement,
270/// as the type inference process might have refined them.
271#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Debug)]
272pub enum TypeVar {
273    /// The base type is known and has a list of parameters
274    Known(Loc<()>, KnownType, Vec<TypeVarID>),
275    /// The type is unknown, but must satisfy the specified traits. When the generic substitution
276    /// is done, the TypeVars will be carried over to the KnownType type vars
277    Unknown(Loc<()>, u64, TraitList, MetaType),
278}
279
280impl TypeVar {
281    pub fn into_known(&self, type_state: &TypeState) -> Option<KnownTypeVar> {
282        match self {
283            TypeVar::Known(loc, base, params) => Some(KnownTypeVar(
284                loc.clone(),
285                base.clone(),
286                params
287                    .iter()
288                    .map(|t| t.resolve(type_state).into_known(type_state))
289                    .collect::<Option<_>>()?,
290            )),
291            TypeVar::Unknown(_, _, _, _) => None,
292        }
293    }
294
295    pub fn insert(self, into: &mut TypeState) -> TypeVarID {
296        into.add_type_var(self)
297    }
298
299    pub fn array(loc: Loc<()>, inner: TypeVarID, size: TypeVarID) -> Self {
300        TypeVar::Known(loc, KnownType::Array, vec![inner, size])
301    }
302
303    pub fn tuple(loc: Loc<()>, inner: Vec<TypeVarID>) -> Self {
304        TypeVar::Known(loc, KnownType::Tuple, inner)
305    }
306
307    pub fn unit(loc: Loc<()>) -> Self {
308        TypeVar::Known(loc, KnownType::Tuple, Vec::new())
309    }
310
311    pub fn wire(loc: Loc<()>, inner: TypeVarID) -> Self {
312        TypeVar::Known(loc, KnownType::Wire, vec![inner])
313    }
314
315    pub fn backward(loc: Loc<()>, inner: TypeVarID, type_state: &mut TypeState) -> Self {
316        TypeVar::Known(
317            loc,
318            KnownType::Inverted,
319            vec![type_state.add_type_var(TypeVar::Known(loc, KnownType::Wire, vec![inner]))],
320        )
321    }
322
323    pub fn inverted(loc: Loc<()>, inner: TypeVarID) -> Self {
324        TypeVar::Known(loc, KnownType::Inverted, vec![inner])
325    }
326
327    pub fn expect_known<T, U, K, O>(&self, on_known: K, on_unknown: U) -> T
328    where
329        U: FnOnce() -> T,
330        K: FnOnce(&KnownType, &[TypeVarID]) -> T,
331    {
332        match self {
333            TypeVar::Unknown(_, _, _, _) => on_unknown(),
334            TypeVar::Known(_, k, v) => on_known(k, v),
335        }
336    }
337
338    pub fn expect_named<T, E, U, K, O>(
339        &self,
340        on_named: K,
341        on_unknown: U,
342        on_other: O,
343        on_error: E,
344    ) -> T
345    where
346        U: FnOnce() -> T,
347        K: FnOnce(&NameID, &[TypeVarID]) -> T,
348        E: FnOnce() -> T,
349        O: FnOnce(&TypeVar) -> T,
350    {
351        match self {
352            TypeVar::Unknown(_, _, _, _) => on_unknown(),
353            TypeVar::Known(_, KnownType::Named(name), params) => on_named(name, params),
354            TypeVar::Known(_, KnownType::Error, _) => on_error(),
355            other => on_other(other),
356        }
357    }
358
359    /// Expect a named type, or TypeVar::Inverted(named), or a recursive inversion.
360    /// inverted_now is stateful and used to track if we are currently in an
361    /// inverted context.
362    pub fn resolve_named_or_inverted(
363        &self,
364        inverted_now: bool,
365        type_state: &TypeState,
366    ) -> ResolvedNamedOrInverted {
367        match self {
368            TypeVar::Unknown(_, _, _, _) => ResolvedNamedOrInverted::Unknown,
369            TypeVar::Known(_, KnownType::Inverted, params) => {
370                if params.len() != 1 {
371                    panic!("Found wire with {} params", params.len())
372                }
373                params[0]
374                    .resolve(type_state)
375                    .resolve_named_or_inverted(!inverted_now, type_state)
376            }
377            TypeVar::Known(_, KnownType::Named(name), params) => {
378                ResolvedNamedOrInverted::Named(inverted_now, name.clone(), params.clone())
379            }
380            _ => ResolvedNamedOrInverted::Other,
381        }
382    }
383
384    pub fn expect_specific_named<T, U, K, O>(
385        &self,
386        name: NameID,
387        on_correct: K,
388        on_unknown: U,
389        on_other: O,
390    ) -> T
391    where
392        U: FnOnce() -> T,
393        K: FnOnce(&[TypeVarID]) -> T,
394        O: FnOnce(&TypeVar) -> T,
395    {
396        match self {
397            TypeVar::Unknown(_, _, _, _) => on_unknown(),
398            TypeVar::Known(_, k, v) if k == &KnownType::Named(name) => on_correct(v),
399            other => on_other(other),
400        }
401    }
402
403    /// Assumes that this type is KnownType::Integer(size) and calls on_integer then. Otherwise
404    /// calls on_unknown or on_other depending on the type. If the integer is given type params,
405    /// panics
406    pub fn expect_integer<T, E, U, K, O>(
407        &self,
408        on_integer: K,
409        on_unknown: U,
410        on_other: O,
411        on_error: E,
412    ) -> T
413    where
414        U: FnOnce() -> T,
415        E: FnOnce() -> T,
416        K: FnOnce(BigInt) -> T,
417        O: FnOnce(&TypeVar) -> T,
418    {
419        match self {
420            TypeVar::Known(_, KnownType::Integer(size), params) => {
421                assert!(params.is_empty());
422                on_integer(size.clone())
423            }
424            TypeVar::Known(_, KnownType::Error, _) => on_error(),
425            TypeVar::Unknown(_, _, _, _) => on_unknown(),
426            other => on_other(other),
427        }
428    }
429
430    /// Assumes that this type is KnownType::String(val) and calls on_string then. Otherwise
431    /// calls on_unknown or on_other depending on the type. If the integer is given type params,
432    /// panics
433    pub fn expect_string<T, E, U, K, O>(
434        &self,
435        on_string: K,
436        on_unknown: U,
437        on_other: O,
438        on_error: E,
439    ) -> T
440    where
441        U: FnOnce() -> T,
442        E: FnOnce() -> T,
443        K: FnOnce(String) -> T,
444        O: FnOnce(&TypeVar) -> T,
445    {
446        match self {
447            TypeVar::Known(_, KnownType::String(val), params) => {
448                assert!(params.is_empty());
449                on_string(val.clone())
450            }
451            TypeVar::Known(_, KnownType::Error, _) => on_error(),
452            TypeVar::Unknown(_, _, _, _) => on_unknown(),
453            other => on_other(other),
454        }
455    }
456
457    pub fn display(&self, type_state: &TypeState) -> String {
458        self.display_with_meta(false, type_state)
459    }
460
461    pub fn display_with_meta(&self, display_meta: bool, type_state: &TypeState) -> String {
462        match self {
463            TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
464            TypeVar::Known(_, KnownType::Named(t), params) => {
465                let generics = if params.is_empty() {
466                    String::new()
467                } else {
468                    format!(
469                        "<{}>",
470                        params
471                            .iter()
472                            .map(|p| format!("{}", p.display_with_meta(display_meta, type_state)))
473                            .collect::<Vec<_>>()
474                            .join(", ")
475                    )
476                };
477                format!("{}{}", t, generics)
478            }
479            TypeVar::Known(_, KnownType::Integer(inner), _) => {
480                format!("{inner}")
481            }
482            TypeVar::Known(_, KnownType::Bool(inner), _) => {
483                format!("{inner}")
484            }
485            TypeVar::Known(_, KnownType::String(inner), _) => {
486                format!("{inner:?}")
487            }
488            TypeVar::Known(_, KnownType::Tuple, params) => {
489                format!(
490                    "({})",
491                    params
492                        .iter()
493                        .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
494                        .collect::<Vec<_>>()
495                        .join(", ")
496                )
497            }
498            TypeVar::Known(_, KnownType::Array, params) => {
499                format!(
500                    "[{}; {}]",
501                    params[0].display_with_meta(display_meta, type_state),
502                    params[1].display_with_meta(display_meta, type_state)
503                )
504            }
505            TypeVar::Known(_, KnownType::Wire, params) => {
506                format!("&{}", params[0].display_with_meta(display_meta, type_state))
507            }
508            TypeVar::Known(_, KnownType::Inverted, params) => {
509                format!(
510                    "inv {}",
511                    params[0].display_with_meta(display_meta, type_state)
512                )
513            }
514            TypeVar::Unknown(_, _, traits, meta) if traits.inner.is_empty() => {
515                if display_meta {
516                    format!("{meta}")
517                } else {
518                    "_".to_string()
519                }
520            }
521            // If we have traits, we know this is a type
522            TypeVar::Unknown(_, _, traits, _meta) => {
523                format!(
524                    "{}",
525                    traits
526                        .inner
527                        .iter()
528                        .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
529                        .join("+"),
530                )
531            }
532        }
533    }
534}
535
536#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
537pub struct KnownTypeVar(pub Loc<()>, pub KnownType, pub Vec<KnownTypeVar>);
538
539impl KnownTypeVar {
540    pub fn insert(&self, type_state: &mut TypeState) -> TypeVarID {
541        let KnownTypeVar(loc, base, params) = self;
542        TypeVar::Known(
543            loc.clone(),
544            base.clone(),
545            params.into_iter().map(|p| p.insert(type_state)).collect(),
546        )
547        .insert(type_state)
548    }
549}
550
551impl std::fmt::Display for KnownTypeVar {
552    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553        let KnownTypeVar(_, base, params) = self;
554
555        match base {
556            KnownType::Error => {
557                write!(f, "{{unknown}}")
558            }
559            KnownType::Named(name_id) => {
560                write!(f, "{name_id}")?;
561                if !params.is_empty() {
562                    write!(f, "<{}>", params.iter().map(|t| format!("{t}")).join(", "))?;
563                }
564                Ok(())
565            }
566            KnownType::Integer(val) => write!(f, "{val}"),
567            KnownType::Bool(val) => write!(f, "{val}"),
568            KnownType::String(val) => write!(f, "{val:?}"),
569            KnownType::Tuple => {
570                write!(f, "({})", params.iter().map(|t| format!("{t}")).join(", "))
571            }
572            KnownType::Array => write!(f, "[{}; {}]", params[0], params[1]),
573            KnownType::Wire => write!(f, "&{}", params[0]),
574            KnownType::Inverted => write!(f, "inv {}", params[0]),
575        }
576    }
577}
578
579pub enum ResolvedNamedOrInverted {
580    Unknown,
581    Named(bool, NameID, Vec<TypeVarID>),
582    Other,
583}
584
585#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
586pub enum TypedExpression {
587    Id(ExprID),
588    Name(NameID),
589}
590
591impl std::fmt::Display for TypedExpression {
592    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593        match self {
594            TypedExpression::Id(i) => {
595                write!(f, "%{}", i.0)
596            }
597            TypedExpression::Name(p) => {
598                write!(f, "{}#{}", p, p.0)
599            }
600        }
601    }
602}