spade_typeinference/
equation.rs

1use itertools::Itertools;
2use std::collections::{BTreeSet, HashMap};
3
4use num::BigInt;
5use serde::{Deserialize, Serialize};
6use spade_common::{
7    id_tracker::ExprID,
8    location_info::{Loc, WithLocation},
9    name::NameID,
10};
11use spade_hir::TraitName;
12use spade_types::{meta_types::MetaType, KnownType};
13
/// Maps every typed expression (an expression id or a name) to its type variable.
pub type TypeEquations = HashMap<TypedExpression, TypeVar>;
15
/// A requirement that a type satisfies the trait `name`, instantiated with `type_params`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct TraitReq {
    /// The required trait.
    pub name: TraitName,
    /// Type parameters of the trait; empty for traits without generic parameters.
    pub type_params: Vec<TypeVar>,
}

impl WithLocation for TraitReq {}
23
24impl TraitReq {
25    pub fn display_with_meta(&self, display_meta: bool) -> String {
26        if self.type_params.is_empty() {
27            format!("{}", self.name)
28        } else {
29            format!(
30                "{}<{}>",
31                self.name,
32                self.type_params
33                    .iter()
34                    .map(|t| format!("{}", t.display_with_meta(display_meta)))
35                    .join(", ")
36            )
37        }
38    }
39}
40
41impl std::fmt::Display for TraitReq {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        write!(f, "{}", self.display_with_meta(false))
44    }
45}
46impl std::fmt::Debug for TraitReq {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        if self.type_params.is_empty() {
49            write!(f, "{}", self.name)
50        } else {
51            write!(
52                f,
53                "{}<{}>",
54                self.name,
55                self.type_params.iter().map(|t| format!("{t:?}")).join(", ")
56            )
57        }
58    }
59}
60
/// A list of trait requirements, each wrapped in a `Loc`.
///
/// NOTE: `TraitList` implements `PartialEq`, `Eq`, `PartialOrd`, `Ord` and `Hash` by hand so
/// that every list compares equal to every other and hashes to nothing. The contents of
/// `inner` therefore never influence comparisons of the types that embed a `TraitList`.
#[derive(Clone, Serialize, Deserialize)]
pub struct TraitList {
    /// The individual requirements.
    pub inner: Vec<Loc<TraitReq>>,
}
65
66impl TraitList {
67    pub fn empty() -> Self {
68        Self { inner: vec![] }
69    }
70
71    pub fn from_vec(inner: Vec<Loc<TraitReq>>) -> Self {
72        Self { inner }
73    }
74
75    pub fn get_trait(&self, name: &TraitName) -> Option<&Loc<TraitReq>> {
76        self.inner.iter().find(|t| &t.name == name)
77    }
78
79    pub fn get_trait_with_type_params(
80        &self,
81        name: &TraitName,
82        type_params: &[TypeVar],
83    ) -> Option<&Loc<TraitReq>> {
84        self.inner
85            .iter()
86            .find(|t| &t.name == name && &t.type_params.as_slice() == &type_params)
87    }
88
89    pub fn extend(self, other: Self) -> Self {
90        let merged = self
91            .inner
92            .into_iter()
93            .chain(other.inner.into_iter())
94            .collect::<BTreeSet<_>>()
95            .into_iter()
96            .collect_vec();
97
98        TraitList { inner: merged }
99    }
100
101    pub fn display_with_meta(&self, display_meta: bool) -> String {
102        self.inner
103            .iter()
104            .map(|t| t.inner.display_with_meta(display_meta))
105            .join(" + ")
106    }
107}
108
109// NOTE: The trait information is currently carried along with the type vars, but
110// the trait information should not be involved in comparisons
111impl PartialEq for TraitList {
112    fn eq(&self, _other: &Self) -> bool {
113        true
114    }
115}
116impl Eq for TraitList {}
117impl std::hash::Hash for TraitList {
118    fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
119}
120impl PartialOrd for TraitList {
121    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122        Some(self.cmp(other))
123    }
124}
125impl Ord for TraitList {
126    fn cmp(&self, _other: &Self) -> std::cmp::Ordering {
127        std::cmp::Ordering::Equal
128    }
129}
130
131impl std::fmt::Display for TraitList {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        write!(f, "{}", self.display_with_meta(false))
134    }
135}
136impl std::fmt::Debug for TraitList {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(f, "{}", self.display_with_meta(true))
139    }
140}
141
/// A type variable represents the type of something in the program. It is mapped
/// to expressions by type equations in the TypeState.
///
/// When TypeVars are passed externally into TypeState, they must be checked for replacement,
/// as the type inference process might have refined them.
///
/// The derived `PartialEq`/`Ord`/`Hash` ignore the `TraitList` stored in `Unknown`, because
/// every `TraitList` compares equal and hashes to nothing.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum TypeVar {
    /// The base type is known and has a list of parameters.
    ///
    /// Fields: location, base type, and type parameters. For example, an array's parameters
    /// are `[element, size]` and a wire's parameter is `[inner]`.
    Known(Loc<()>, KnownType, Vec<TypeVar>),
    /// The type is unknown, but must satisfy the specified traits. When the generic substitution
    /// is done, the TypeVars will be carried over to the KnownType type vars.
    ///
    /// Fields: location, numeric id (printed as `t<id>` by `Debug`), required traits, and the
    /// meta type.
    Unknown(Loc<()>, u64, TraitList, MetaType),
}

impl WithLocation for TypeVar {}
157
158impl TypeVar {
159    pub fn array(loc: Loc<()>, inner: TypeVar, size: TypeVar) -> Self {
160        TypeVar::Known(loc, KnownType::Array, vec![inner, size])
161    }
162
163    pub fn tuple(loc: Loc<()>, inner: Vec<TypeVar>) -> Self {
164        TypeVar::Known(loc, KnownType::Tuple, inner)
165    }
166
167    pub fn unit(loc: Loc<()>) -> Self {
168        TypeVar::Known(loc, KnownType::Tuple, Vec::new())
169    }
170
171    pub fn wire(loc: Loc<()>, inner: TypeVar) -> Self {
172        TypeVar::Known(loc, KnownType::Wire, vec![inner])
173    }
174
175    pub fn backward(loc: Loc<()>, inner: TypeVar) -> Self {
176        TypeVar::Known(
177            loc,
178            KnownType::Inverted,
179            vec![TypeVar::Known(loc, KnownType::Wire, vec![inner])],
180        )
181    }
182
183    pub fn inverted(loc: Loc<()>, inner: TypeVar) -> Self {
184        TypeVar::Known(loc, KnownType::Inverted, vec![inner])
185    }
186
187    pub fn expect_known<T, U, K, O>(&self, on_known: K, on_unknown: U) -> T
188    where
189        U: FnOnce() -> T,
190        K: FnOnce(&KnownType, &[TypeVar]) -> T,
191    {
192        match self {
193            TypeVar::Unknown(_, _, _, _) => on_unknown(),
194            TypeVar::Known(_, k, v) => on_known(k, v),
195        }
196    }
197
198    pub fn expect_named<T, U, K, O>(&self, on_named: K, on_unknown: U, on_other: O) -> T
199    where
200        U: FnOnce() -> T,
201        K: FnOnce(&NameID, &[TypeVar]) -> T,
202        O: FnOnce(&TypeVar) -> T,
203    {
204        match self {
205            TypeVar::Unknown(_, _, _, _) => on_unknown(),
206            TypeVar::Known(_, KnownType::Named(name), params) => on_named(name, params),
207            other => on_other(other),
208        }
209    }
210
211    /// Expect a named type, or TypeVar::Inverted(named), or a recursive inversion.
212    /// inverted_now is stateful and used to track if we are currently in an
213    /// inverted context.
214    /// The first argument of the on_named function specifies whether or not
215    /// the final named type we found was inverted or not. I.e. if we ran it on
216    /// `inv T`, it would be called with (true, T), and if we ran it on `T` it would
217    /// be called with `(false, T)`
218    pub fn expect_named_or_inverted<T, U, K, O>(
219        &self,
220        inverted_now: bool,
221        on_named: K,
222        on_unknown: U,
223        on_other: O,
224    ) -> T
225    where
226        U: FnOnce() -> T,
227        K: FnOnce(bool, &NameID, &[TypeVar]) -> T,
228        O: FnOnce(&TypeVar) -> T,
229    {
230        match self {
231            TypeVar::Unknown(_, _, _, _) => on_unknown(),
232            TypeVar::Known(_, KnownType::Inverted, params) => {
233                if params.len() != 1 {
234                    panic!("Found wire with {} params", params.len())
235                }
236                params[0].expect_named_or_inverted(!inverted_now, on_named, on_unknown, on_other)
237            }
238            TypeVar::Known(_, KnownType::Named(name), params) => {
239                on_named(inverted_now, name, params)
240            }
241            other => on_other(other),
242        }
243    }
244
245    pub fn expect_specific_named<T, U, K, O>(
246        &self,
247        name: NameID,
248        on_correct: K,
249        on_unknown: U,
250        on_other: O,
251    ) -> T
252    where
253        U: FnOnce() -> T,
254        K: FnOnce(&[TypeVar]) -> T,
255        O: FnOnce(&TypeVar) -> T,
256    {
257        match self {
258            TypeVar::Unknown(_, _, _, _) => on_unknown(),
259            TypeVar::Known(_, k, v) if k == &KnownType::Named(name) => on_correct(v),
260            other => on_other(other),
261        }
262    }
263
264    /// Assumes that this type is KnownType::Integer(size) and calls on_integer then. Otherwise
265    /// calls on_unknown or on_other depending on the type. If the integer is given type params,
266    /// panics
267    pub fn expect_integer<T, U, K, O>(&self, on_integer: K, on_unknown: U, on_other: O) -> T
268    where
269        U: FnOnce() -> T,
270        K: FnOnce(BigInt) -> T,
271        O: FnOnce(&TypeVar) -> T,
272    {
273        match self {
274            TypeVar::Known(_, KnownType::Integer(size), params) => {
275                assert!(params.is_empty());
276                on_integer(size.clone())
277            }
278            TypeVar::Unknown(_, _, _, _) => on_unknown(),
279            other => on_other(other),
280        }
281    }
282
283    pub fn display_with_meta(&self, display_meta: bool) -> String {
284        match self {
285            TypeVar::Known(_, KnownType::Named(t), params) => {
286                let generics = if params.is_empty() {
287                    String::new()
288                } else {
289                    format!(
290                        "<{}>",
291                        params
292                            .iter()
293                            .map(|p| format!("{}", p.display_with_meta(display_meta)))
294                            .collect::<Vec<_>>()
295                            .join(", ")
296                    )
297                };
298                format!("{}{}", t, generics)
299            }
300            TypeVar::Known(_, KnownType::Integer(inner), _) => {
301                format!("{inner}")
302            }
303            TypeVar::Known(_, KnownType::Bool(inner), _) => {
304                format!("{inner}")
305            }
306            TypeVar::Known(_, KnownType::Tuple, params) => {
307                format!(
308                    "({})",
309                    params
310                        .iter()
311                        .map(|t| format!("{}", t.display_with_meta(display_meta)))
312                        .collect::<Vec<_>>()
313                        .join(", ")
314                )
315            }
316            TypeVar::Known(_, KnownType::Array, params) => {
317                format!(
318                    "[{}; {}]",
319                    params[0].display_with_meta(display_meta),
320                    params[1].display_with_meta(display_meta)
321                )
322            }
323            TypeVar::Known(_, KnownType::Wire, params) => {
324                format!("&{}", params[0].display_with_meta(display_meta))
325            }
326            TypeVar::Known(_, KnownType::Inverted, params) => {
327                format!("inv {}", params[0].display_with_meta(display_meta))
328            }
329            TypeVar::Unknown(_, _, traits, meta) if traits.inner.is_empty() => {
330                if display_meta {
331                    format!("{meta}")
332                } else {
333                    "_".to_string()
334                }
335            }
336            // If we have traits, we know this is a type
337            TypeVar::Unknown(_, _, traits, _meta) => {
338                format!(
339                    "{}",
340                    traits
341                        .inner
342                        .iter()
343                        .map(|t| format!("{}", t.display_with_meta(display_meta)))
344                        .join("+"),
345                )
346            }
347        }
348    }
349}
350
351impl std::fmt::Debug for TypeVar {
352    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353        match self {
354            TypeVar::Known(_, KnownType::Named(t), params) => {
355                let generics = if params.is_empty() {
356                    String::new()
357                } else {
358                    format!(
359                        "<{}>",
360                        params
361                            .iter()
362                            .map(|p| format!("{:?}", p))
363                            .collect::<Vec<_>>()
364                            .join(", ")
365                    )
366                };
367                write!(f, "{}{}", t, generics)
368            }
369            TypeVar::Known(_, KnownType::Integer(inner), _) => {
370                write!(f, "{inner}")
371            }
372            TypeVar::Known(_, KnownType::Bool(inner), _) => {
373                write!(f, "{inner}")
374            }
375            TypeVar::Known(_, KnownType::Tuple, params) => {
376                write!(
377                    f,
378                    "({})",
379                    params
380                        .iter()
381                        .map(|t| format!("{:?}", t))
382                        .collect::<Vec<_>>()
383                        .join(", ")
384                )
385            }
386            TypeVar::Known(_, KnownType::Array, params) => {
387                write!(f, "[{:?}; {:?}]", params[0], params[1])
388            }
389            TypeVar::Known(_, KnownType::Wire, params) => write!(f, "&{:?}", params[0]),
390            TypeVar::Known(_, KnownType::Inverted, params) => write!(f, "inv {:?}", params[0]),
391            TypeVar::Unknown(_, id, traits, meta_type) => write!(
392                f,
393                "t{id}({}, {meta_type})",
394                traits.inner.iter().map(|t| format!("{t}")).join("+")
395            ),
396        }
397    }
398}
399
400impl std::fmt::Display for TypeVar {
401    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402        write!(f, "{}", self.display_with_meta(false))
403    }
404}
405
/// Something that can be assigned a type: an expression identified by its `ExprID`, or a
/// name identified by its `NameID`.
#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
pub enum TypedExpression {
    /// An expression, referenced by id.
    Id(ExprID),
    /// A named item, referenced by its `NameID`.
    Name(NameID),
}

impl WithLocation for TypedExpression {}
413
414impl std::fmt::Display for TypedExpression {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        match self {
417            TypedExpression::Id(i) => {
418                write!(f, "%{}", i.0)
419            }
420            TypedExpression::Name(p) => {
421                write!(f, "{}#{}", p, p.0)
422            }
423        }
424    }
425}