spade_typeinference/
error.rs

1use itertools::Itertools;
2
3use spade_common::location_info::{FullSpan, Loc, WithLocation};
4use spade_diagnostics::Diagnostic;
5
6use crate::{constraints::ConstraintSource, equation::TypeVarID, traits::TraitReq, TypeState};
7
8use super::equation::TypeVar;
9
/// A trace of a unification error. The `failing` field indicates which exact type failed to
/// unify, while `inside` is the "top level" type which failed to unify, if that is not the same
/// as `failing`.
///
/// For example, if unifying `int<7>` with `int<8>`, this would be `failing: 8, inside: int<8>`,
/// while if unifying `int<7>` with `bool`, `inside` would be `None`.
#[derive(Debug, PartialEq, Clone)]
pub struct UnificationTrace {
    /// The exact type variable that failed to unify.
    pub failing: TypeVarID,
    /// The top-level type containing `failing`, or `None` when `failing` is itself the top
    /// level.
    pub inside: Option<TypeVarID>,
}
// Empty impl: only the items `WithLocation` provides by default are used for this type.
impl WithLocation for UnificationTrace {}
22
23impl UnificationTrace {
24    pub fn new(failing: TypeVarID) -> Self {
25        Self {
26            failing,
27            inside: None,
28        }
29    }
30
31    pub fn outer(&self) -> TypeVarID {
32        *self.inside.as_ref().unwrap_or(&self.failing)
33    }
34
35    pub fn display_with_meta(&self, meta: bool, type_state: &TypeState) -> String {
36        self.inside
37            .as_ref()
38            .unwrap_or(&self.failing)
39            .display_with_meta(meta, type_state)
40    }
41
42    pub fn display(&self, type_state: &TypeState) -> String {
43        self.outer().display_with_meta(false, type_state)
44    }
45}
46
47pub trait UnificationErrorExt<T>: Sized {
48    fn add_context(
49        self,
50        got: TypeVarID,
51        expected: TypeVarID,
52    ) -> std::result::Result<T, UnificationError>;
53
54    /// Creates a diagnostic with a generic type mismatch error
55    fn into_default_diagnostic(
56        self,
57        unification_point: impl Into<FullSpan> + Clone,
58        type_state: &TypeState,
59    ) -> std::result::Result<T, Diagnostic> {
60        self.into_diagnostic(unification_point, |d, _| d, type_state)
61    }
62
63    /// Creates a diagnostic with a generic type mismatch error
64    fn into_diagnostic_or_default<F>(
65        self,
66        unification_point: impl Into<FullSpan> + Clone,
67        message: Option<F>,
68        type_state: &TypeState,
69    ) -> std::result::Result<T, Diagnostic>
70    where
71        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
72    {
73        if let Some(message) = message {
74            self.into_diagnostic(unification_point, message, type_state)
75        } else {
76            self.into_diagnostic(unification_point, |d, _| d, type_state)
77        }
78    }
79
80    /// Creates a diagnostic from the unification error that will be emitted at the unification
81    /// point, unless the unification error was caused by constraints, at which point
82    /// the source of those constraints will be the location of the error.
83    /// If trait constraints were not met, a default message will be provided at the unification
84    /// point
85    fn into_diagnostic<F>(
86        self,
87        unification_point: impl Into<FullSpan> + Clone,
88        message: F,
89        type_state: &TypeState,
90    ) -> std::result::Result<T, Diagnostic>
91    where
92        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
93    {
94        self.into_diagnostic_impl(unification_point, false, message, type_state)
95    }
96
97    fn into_diagnostic_no_expected_source<F>(
98        self,
99        unification_point: impl Into<FullSpan> + Clone,
100        message: F,
101        type_state: &TypeState,
102    ) -> std::result::Result<T, Diagnostic>
103    where
104        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
105    {
106        self.into_diagnostic_impl(unification_point, true, message, type_state)
107    }
108
109    fn into_diagnostic_impl<F>(
110        self,
111        unification_point: impl Into<FullSpan> + Clone,
112        omit_expected_source: bool,
113        message: F,
114        type_state: &TypeState,
115    ) -> std::result::Result<T, Diagnostic>
116    where
117        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic;
118}
119impl<T> UnificationErrorExt<T> for std::result::Result<T, UnificationError> {
120    fn add_context(
121        self,
122        got: TypeVarID,
123        expected: TypeVarID,
124    ) -> std::result::Result<T, UnificationError> {
125        match self {
126            Ok(val) => Ok(val),
127            Err(UnificationError::Normal(TypeMismatch {
128                e: mut old_e,
129                g: mut old_g,
130            })) => {
131                old_e.inside.replace(expected);
132                old_g.inside.replace(got);
133                Err(UnificationError::Normal(TypeMismatch {
134                    e: old_e,
135                    g: old_g,
136                }))
137            }
138            Err(UnificationError::MetaMismatch(TypeMismatch {
139                e: mut old_e,
140                g: mut old_g,
141            })) => {
142                old_e.inside.replace(expected);
143                old_g.inside.replace(got);
144                Err(UnificationError::MetaMismatch(TypeMismatch {
145                    e: old_e,
146                    g: old_g,
147                }))
148            }
149            e @ Err(UnificationError::RecursiveType { .. }) => e,
150            e @ Err(UnificationError::UnsatisfiedTraits { .. }) => e,
151            e @ Err(
152                UnificationError::FromConstraints { .. } | UnificationError::Specific { .. },
153            ) => e,
154        }
155    }
156
157    fn into_diagnostic_impl<F>(
158        self,
159        unification_point: impl Into<FullSpan> + Clone,
160        omit_expected_source: bool,
161        message: F,
162        type_state: &TypeState,
163    ) -> std::result::Result<T, Diagnostic>
164    where
165        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
166    {
167        self.map_err(|err| {
168            let display_meta = match &err {
169                UnificationError::Normal { .. } => false,
170                UnificationError::MetaMismatch { .. } => true,
171                _ => false,
172            };
173            match err {
174                UnificationError::Normal(TypeMismatch { e, g })
175                | UnificationError::MetaMismatch(TypeMismatch { e, g }) => {
176                    let e_disp = e.display_with_meta(display_meta, type_state);
177                    let g_disp = g.display_with_meta(display_meta, type_state);
178                    let msg = format!("Expected type {e_disp}, got {g_disp}");
179                    let diag = Diagnostic::error(unification_point.clone(), msg)
180                        .primary_label(format!("Expected {e_disp}"));
181                    let diag = message(
182                        diag,
183                        TypeMismatch {
184                            e: e.clone(),
185                            g: g.clone(),
186                        },
187                    );
188
189                    let diag = if !omit_expected_source {
190                        add_known_type_context(
191                            diag,
192                            unification_point.clone(),
193                            &e.failing,
194                            display_meta,
195                            type_state,
196                        )
197                    } else {
198                        diag
199                    };
200
201                    let diag = add_known_type_context(
202                        diag,
203                        unification_point,
204                        &g.failing,
205                        display_meta,
206                        type_state,
207                    );
208                    diag.type_error(
209                        format!("{}", e.failing.display_with_meta(display_meta, type_state)),
210                        e.inside
211                            .map(|o| o.display_with_meta(display_meta, type_state)),
212                        format!("{}", g.failing.display_with_meta(display_meta, type_state)),
213                        g.inside
214                            .map(|o| o.display_with_meta(display_meta, type_state)),
215                    )
216                }
217                UnificationError::UnsatisfiedTraits {
218                    var,
219                    traits,
220                    target_loc: _,
221                } => {
222                    let trait_bound_loc = ().at_loc(&traits[0]);
223                    let impls_str = if traits.len() >= 2 {
224                        format!(
225                            "{} and {}",
226                            traits[0..traits.len() - 1]
227                                .iter()
228                                .map(|i| i.inner.display_with_meta(display_meta, type_state))
229                                .join(", "),
230                            traits[traits.len() - 1].display_with_meta(display_meta, type_state)
231                        )
232                    } else {
233                        format!("{}", traits[0].display_with_meta(display_meta, type_state))
234                    };
235                    let short_msg = format!(
236                        "{var} does not implement {impls_str}",
237                        var = var.display_with_meta(display_meta, type_state)
238                    );
239                    Diagnostic::error(
240                        unification_point,
241                        format!("Trait bound not satisfied. {short_msg}"),
242                    )
243                    .primary_label(short_msg)
244                    .secondary_label(
245                        trait_bound_loc,
246                        "Required because of the trait bound specified here",
247                    )
248                }
249                UnificationError::FromConstraints {
250                    expected,
251                    got,
252                    source,
253                    loc,
254                    is_meta_error,
255                } => {
256                    let diag = Diagnostic::error(
257                        loc,
258                        format!(
259                            "Expected type {}, got {}",
260                            expected.display_with_meta(is_meta_error, type_state),
261                            got.display_with_meta(is_meta_error, type_state)
262                        ),
263                    )
264                    .primary_label(format!(
265                        "Expected {}, got {}",
266                        expected.display_with_meta(is_meta_error, type_state),
267                        got.display_with_meta(is_meta_error, type_state)
268                    ));
269
270                    let diag = diag.type_error(
271                        format!(
272                            "{}",
273                            expected
274                                .failing
275                                .display_with_meta(is_meta_error, type_state)
276                        ),
277                        expected
278                            .inside
279                            .as_ref()
280                            .map(|o| o.display_with_meta(is_meta_error, type_state)),
281                        format!(
282                            "{}",
283                            got.failing.display_with_meta(is_meta_error, type_state)
284                        ),
285                        got.inside
286                            .as_ref()
287                            .map(|o| o.display_with_meta(is_meta_error, type_state)),
288                    );
289
290                    match source {
291                        ConstraintSource::AdditionOutput => diag.note(
292                            "Addition creates one more output bit than the input to avoid overflow"
293                                .to_string(),
294                        ),
295                        ConstraintSource::MultOutput => diag.note(
296                            "The size of a multiplication is the sum of the operand sizes"
297                                .to_string(),
298                        ),
299                        ConstraintSource::ArrayIndexing => {
300                            // NOTE: This error message could probably be improved
301                            diag.note(
302                                "because the value is used as an index to an array".to_string(),
303                            )
304                        }
305                        ConstraintSource::MemoryIndexing => {
306                            // NOTE: This error message could probably be improved
307                            diag.note(
308                                "because the value is used as an index to a memory".to_string(),
309                            )
310                        }
311                        ConstraintSource::Concatenation => diag.note(
312                            "The size of a concatenation is the sum of the operand sizes"
313                                .to_string(),
314                        ),
315                        ConstraintSource::ArraySize => {
316                            diag.note("The number of array elements must  match")
317                        }
318                        ConstraintSource::RangeIndex => diag,
319                        ConstraintSource::RangeIndexOutputSize => diag.note(
320                            "The output of a range index is an array inferred from the indices",
321                        ),
322                        ConstraintSource::TypeLevelIf | ConstraintSource::Where => diag,
323                        ConstraintSource::PipelineRegOffset { .. } => diag,
324                        ConstraintSource::PipelineRegCount { reg, total } => Diagnostic::error(
325                            total,
326                            format!(
327                                "Expected {expected} in this pipeline.",
328                                expected = expected.display(type_state)
329                            ),
330                        )
331                        .primary_label(format!(
332                            "Expected {expected} stages",
333                            expected = expected.display(type_state)
334                        ))
335                        .secondary_label(
336                            reg,
337                            format!(
338                                "This final register is number {expected}",
339                                expected = got.display(type_state)
340                            ),
341                        ),
342                        ConstraintSource::PipelineAvailDepth => diag,
343                    }
344                }
345                UnificationError::Specific(e) => e.secondary_label(
346                    unification_point,
347                    "The error occurred when unifying types here",
348                ),
349                UnificationError::RecursiveType(t) => {
350                    Diagnostic::error(unification_point, format!("Inferred a recursive type {t}"))
351                        .primary_label(format!("Type {t} inferred here"))
352                        .help("The recursion happens at the type marked \"*\"")
353                }
354            }
355        })
356    }
357}
358
359fn add_known_type_context(
360    diag: Diagnostic,
361    unification_point: impl Into<FullSpan> + Clone,
362    failing: &TypeVarID,
363    meta: bool,
364    type_state: &TypeState,
365) -> Diagnostic {
366    match failing.resolve(type_state) {
367        TypeVar::Known(k, _, _) => {
368            if FullSpan::from(k) != unification_point.clone().into() {
369                diag.secondary_label(
370                    k,
371                    format!(
372                        "Type {} inferred here",
373                        failing.display_with_meta(meta, type_state)
374                    ),
375                )
376            } else {
377                diag
378            }
379        }
380        TypeVar::Unknown(k, _, _, _) => {
381            if FullSpan::from(k) != unification_point.clone().into() {
382                diag.secondary_label(
383                    k,
384                    format!(
385                        "Type {} inferred here",
386                        failing.display_with_meta(meta, type_state)
387                    ),
388                )
389            } else {
390                diag
391            }
392        }
393    }
394}
395
/// The expected and the actual side of a failed unification, each described by a
/// [`UnificationTrace`].
#[derive(Debug, PartialEq, Clone)]
pub struct TypeMismatch {
    /// Expected type
    pub e: UnificationTrace,
    /// Got type
    pub g: UnificationTrace,
}
403impl TypeMismatch {
404    pub fn is_meta_error(&self, type_state: &TypeState) -> bool {
405        matches!(
406            self.e.failing.resolve(type_state),
407            TypeVar::Unknown(_, _, _, _)
408        ) || matches!(
409            self.g.failing.resolve(type_state),
410            TypeVar::Unknown(_, _, _, _)
411        )
412    }
413
414    pub fn display_e_g(&self, type_state: &TypeState) -> (String, String) {
415        let is_meta = self.is_meta_error(type_state);
416        (
417            self.e.display_with_meta(is_meta, type_state),
418            self.g.display_with_meta(is_meta, type_state),
419        )
420    }
421}
422
/// The ways in which unification can fail. See `UnificationErrorExt::into_diagnostic_impl` for
/// how each variant is turned into a diagnostic.
#[derive(Debug, PartialEq, Clone)]
pub enum UnificationError {
    /// A type mismatch; reported as "Expected type …, got …" without meta information.
    Normal(TypeMismatch),
    /// A type mismatch that is reported like `Normal`, but rendered with meta information.
    MetaMismatch(TypeMismatch),
    /// A recursive type was inferred. The string is inserted verbatim into the message.
    RecursiveType(String),
    /// A ready-made diagnostic. When reported, it only gets an extra secondary label at the
    /// unification point.
    Specific(spade_diagnostics::Diagnostic),
    /// `var` does not implement the listed traits.
    UnsatisfiedTraits {
        /// The type that lacks the trait implementations.
        var: TypeVarID,
        /// The required traits. The location of the first one is used for the "Required
        /// because of the trait bound specified here" label.
        traits: Vec<Loc<TraitReq>>,
        /// Not used when building the diagnostic in this file.
        target_loc: Loc<()>,
    },
    /// A mismatch that arose from a constraint; reported at `loc` rather than at the
    /// unification point.
    FromConstraints {
        /// Expected side of the mismatch.
        expected: UnificationTrace,
        /// Actual side of the mismatch.
        got: UnificationTrace,
        /// What kind of constraint caused the mismatch; selects the note attached to the
        /// diagnostic.
        source: ConstraintSource,
        /// Where the diagnostic is emitted.
        loc: Loc<()>,
        /// Whether the types are rendered with meta information.
        is_meta_error: bool,
    },
}
442
/// Result type whose error is a [`Diagnostic`].
pub type Result<T> = std::result::Result<T, Diagnostic>;
444
445pub fn error_pattern_type_mismatch(
446    reason: Loc<()>,
447    type_state: &TypeState,
448) -> impl Fn(Diagnostic, TypeMismatch) -> Diagnostic + '_ {
449    move |diag, tm| {
450        let (expected, got) = tm.display_e_g(type_state);
451        diag.message(format!(
452            "Pattern type mismatch. Expected {expected} got {got}"
453        ))
454        .primary_label(format!("expected {expected}, got {got}"))
455        .secondary_label(reason, format!("because this has type {expected}"))
456    }
457}