spade_typeinference/
error.rs

1use itertools::Itertools;
2
3use spade_common::location_info::{FullSpan, Loc, WithLocation};
4use spade_diagnostics::Diagnostic;
5
6use crate::{constraints::ConstraintSource, equation::TypeVarID, traits::TraitReq, TypeState};
7
8use super::equation::TypeVar;
9
10/// A trace of a unification error. The `failing` field indicates which exact type failed to unify,
11/// while the `inside` is the "top level" type which failed to unify if it's not the same as
12/// failing.
13///
14/// For example, if unifying `int<7>` with `int<8>`, this would be `failing: 8, inside: int<8>`
15/// while if unifying `int<7>` with `bool`, inside would be `None`
16#[derive(Debug, PartialEq, Clone)]
17pub struct UnificationTrace {
18    pub failing: TypeVarID,
19    pub inside: Option<TypeVarID>,
20}
21
22impl UnificationTrace {
23    pub fn new(failing: TypeVarID) -> Self {
24        Self {
25            failing,
26            inside: None,
27        }
28    }
29
30    pub fn outer(&self) -> TypeVarID {
31        *self.inside.as_ref().unwrap_or(&self.failing)
32    }
33
34    pub fn display_with_meta(&self, meta: bool, type_state: &TypeState) -> String {
35        self.inside
36            .as_ref()
37            .unwrap_or(&self.failing)
38            .display_with_meta(meta, type_state)
39    }
40
41    pub fn display(&self, type_state: &TypeState) -> String {
42        self.outer().display_with_meta(false, type_state)
43    }
44}
45
46pub trait UnificationErrorExt<T>: Sized {
47    fn add_context(
48        self,
49        got: TypeVarID,
50        expected: TypeVarID,
51    ) -> std::result::Result<T, UnificationError>;
52
53    /// Creates a diagnostic with a generic type mismatch error
54    fn into_default_diagnostic(
55        self,
56        unification_point: impl Into<FullSpan> + Clone,
57        type_state: &TypeState,
58    ) -> std::result::Result<T, Diagnostic> {
59        self.into_diagnostic(unification_point, |d, _| d, type_state)
60    }
61
62    /// Creates a diagnostic with a generic type mismatch error
63    fn into_diagnostic_or_default<F>(
64        self,
65        unification_point: impl Into<FullSpan> + Clone,
66        message: Option<F>,
67        type_state: &TypeState,
68    ) -> std::result::Result<T, Diagnostic>
69    where
70        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
71    {
72        if let Some(message) = message {
73            self.into_diagnostic(unification_point, message, type_state)
74        } else {
75            self.into_diagnostic(unification_point, |d, _| d, type_state)
76        }
77    }
78
79    /// Creates a diagnostic from the unification error that will be emitted at the unification
80    /// point, unless the unification error was caused by constraints, at which point
81    /// the source of those constraints will be the location of the error.
82    /// If trait constraints were not met, a default message will be provided at the unification
83    /// point
84    fn into_diagnostic<F>(
85        self,
86        unification_point: impl Into<FullSpan> + Clone,
87        message: F,
88        type_state: &TypeState,
89    ) -> std::result::Result<T, Diagnostic>
90    where
91        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
92    {
93        self.into_diagnostic_impl(unification_point, false, message, type_state)
94    }
95
96    fn into_diagnostic_no_expected_source<F>(
97        self,
98        unification_point: impl Into<FullSpan> + Clone,
99        message: F,
100        type_state: &TypeState,
101    ) -> std::result::Result<T, Diagnostic>
102    where
103        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
104    {
105        self.into_diagnostic_impl(unification_point, true, message, type_state)
106    }
107
108    fn into_diagnostic_impl<F>(
109        self,
110        unification_point: impl Into<FullSpan> + Clone,
111        omit_expected_source: bool,
112        message: F,
113        type_state: &TypeState,
114    ) -> std::result::Result<T, Diagnostic>
115    where
116        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic;
117}
118impl<T> UnificationErrorExt<T> for std::result::Result<T, UnificationError> {
119    fn add_context(
120        self,
121        got: TypeVarID,
122        expected: TypeVarID,
123    ) -> std::result::Result<T, UnificationError> {
124        match self {
125            Ok(val) => Ok(val),
126            Err(UnificationError::Normal(TypeMismatch {
127                e: mut old_e,
128                g: mut old_g,
129            })) => {
130                old_e.inside.replace(expected);
131                old_g.inside.replace(got);
132                Err(UnificationError::Normal(TypeMismatch {
133                    e: old_e,
134                    g: old_g,
135                }))
136            }
137            Err(UnificationError::MetaMismatch(TypeMismatch {
138                e: mut old_e,
139                g: mut old_g,
140            })) => {
141                old_e.inside.replace(expected);
142                old_g.inside.replace(got);
143                Err(UnificationError::MetaMismatch(TypeMismatch {
144                    e: old_e,
145                    g: old_g,
146                }))
147            }
148            e @ Err(UnificationError::RecursiveType { .. }) => e,
149            e @ Err(UnificationError::UnsatisfiedTraits { .. }) => e,
150            e @ Err(
151                UnificationError::FromConstraints { .. } | UnificationError::Specific { .. },
152            ) => e,
153        }
154    }
155
156    fn into_diagnostic_impl<F>(
157        self,
158        unification_point: impl Into<FullSpan> + Clone,
159        omit_expected_source: bool,
160        message: F,
161        type_state: &TypeState,
162    ) -> std::result::Result<T, Diagnostic>
163    where
164        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
165    {
166        self.map_err(|err| {
167            let display_meta = match &err {
168                UnificationError::Normal { .. } => false,
169                UnificationError::MetaMismatch { .. } => true,
170                _ => false,
171            };
172            match err {
173                UnificationError::Normal(TypeMismatch { e, g })
174                | UnificationError::MetaMismatch(TypeMismatch { e, g }) => {
175                    let e_disp = e.display_with_meta(display_meta, type_state);
176                    let g_disp = g.display_with_meta(display_meta, type_state);
177                    let msg = format!("Expected type {e_disp}, got {g_disp}");
178                    let diag = Diagnostic::error(unification_point.clone(), msg)
179                        .primary_label(format!("Expected {e_disp}"));
180                    let diag = message(
181                        diag,
182                        TypeMismatch {
183                            e: e.clone(),
184                            g: g.clone(),
185                        },
186                    );
187
188                    let diag = if !omit_expected_source {
189                        add_known_type_context(
190                            diag,
191                            unification_point.clone(),
192                            &e.failing,
193                            display_meta,
194                            type_state,
195                        )
196                    } else {
197                        diag
198                    };
199
200                    let diag = add_known_type_context(
201                        diag,
202                        unification_point,
203                        &g.failing,
204                        display_meta,
205                        type_state,
206                    );
207                    diag.type_error(
208                        format!("{}", e.failing.display_with_meta(display_meta, type_state)),
209                        e.inside
210                            .map(|o| o.display_with_meta(display_meta, type_state)),
211                        format!("{}", g.failing.display_with_meta(display_meta, type_state)),
212                        g.inside
213                            .map(|o| o.display_with_meta(display_meta, type_state)),
214                    )
215                }
216                UnificationError::UnsatisfiedTraits {
217                    var,
218                    traits,
219                    target_loc: _,
220                } => {
221                    let trait_bound_loc = ().at_loc(&traits[0]);
222                    let impls_str = if traits.len() >= 2 {
223                        format!(
224                            "{} and {}",
225                            traits[0..traits.len() - 1]
226                                .iter()
227                                .map(|i| i.inner.display_with_meta(display_meta, type_state))
228                                .join(", "),
229                            traits[traits.len() - 1].display_with_meta(display_meta, type_state)
230                        )
231                    } else {
232                        format!("{}", traits[0].display_with_meta(display_meta, type_state))
233                    };
234                    let short_msg = format!(
235                        "{var} does not implement {impls_str}",
236                        var = var.display_with_meta(display_meta, type_state)
237                    );
238                    Diagnostic::error(
239                        unification_point,
240                        format!("Trait bound not satisfied. {short_msg}"),
241                    )
242                    .primary_label(short_msg)
243                    .secondary_label(
244                        trait_bound_loc,
245                        "Required because of the trait bound specified here",
246                    )
247                }
248                UnificationError::FromConstraints {
249                    expected,
250                    got,
251                    source,
252                    loc,
253                    is_meta_error,
254                } => {
255                    let diag = Diagnostic::error(
256                        loc,
257                        format!(
258                            "Expected type {}, got {}",
259                            expected.display_with_meta(is_meta_error, type_state),
260                            got.display_with_meta(is_meta_error, type_state)
261                        ),
262                    )
263                    .primary_label(format!(
264                        "Expected {}, got {}",
265                        expected.display_with_meta(is_meta_error, type_state),
266                        got.display_with_meta(is_meta_error, type_state)
267                    ));
268
269                    let diag = diag.type_error(
270                        format!(
271                            "{}",
272                            expected
273                                .failing
274                                .display_with_meta(is_meta_error, type_state)
275                        ),
276                        expected
277                            .inside
278                            .as_ref()
279                            .map(|o| o.display_with_meta(is_meta_error, type_state)),
280                        format!(
281                            "{}",
282                            got.failing.display_with_meta(is_meta_error, type_state)
283                        ),
284                        got.inside
285                            .as_ref()
286                            .map(|o| o.display_with_meta(is_meta_error, type_state)),
287                    );
288
289                    let diag = match source {
290                        ConstraintSource::AdditionOutput => diag.note(
291                            "Addition creates one more output bit than the input to avoid overflow"
292                                .to_string(),
293                        ),
294                        ConstraintSource::MultOutput => diag.note(
295                            "The size of a multiplication is the sum of the operand sizes"
296                                .to_string(),
297                        ),
298                        ConstraintSource::ArrayIndexing => {
299                            // NOTE: This error message could probably be improved
300                            diag.note(
301                                "because the value is used as an index to an array".to_string(),
302                            )
303                        }
304                        ConstraintSource::MemoryIndexing => {
305                            // NOTE: This error message could probably be improved
306                            diag.note(
307                                "because the value is used as an index to a memory".to_string(),
308                            )
309                        }
310                        ConstraintSource::Concatenation => diag.note(
311                            "The size of a concatenation is the sum of the operand sizes"
312                                .to_string(),
313                        ),
314                        ConstraintSource::ArraySize => {
315                            diag.note("The number of array elements must  match")
316                        }
317                        ConstraintSource::RangeIndex => diag,
318                        ConstraintSource::RangeIndexOutputSize => diag.note(
319                            "The output of a range index is an array inferred from the indices",
320                        ),
321                        ConstraintSource::TypeLevelIf | ConstraintSource::Where => diag,
322                        ConstraintSource::PipelineRegOffset { .. } => diag,
323                        ConstraintSource::PipelineRegCount { reg, total } => Diagnostic::error(
324                            total,
325                            format!(
326                                "Expected {expected} in this pipeline.",
327                                expected = expected.display(type_state)
328                            ),
329                        )
330                        .primary_label(format!(
331                            "Expected {expected} stages",
332                            expected = expected.display(type_state)
333                        ))
334                        .secondary_label(
335                            reg,
336                            format!(
337                                "This final register is number {expected}",
338                                expected = got.display(type_state)
339                            ),
340                        ),
341                        ConstraintSource::PipelineAvailDepth => diag,
342                    };
343
344                    if unification_point.clone().into().0 != loc.span {
345                        diag.secondary_label(
346                            unification_point,
347                            "The error occurred when unifying types here",
348                        )
349                    } else {
350                        diag
351                    }
352                }
353                UnificationError::Specific(e) => e.secondary_label(
354                    unification_point,
355                    "The error occurred when unifying types here",
356                ),
357                UnificationError::RecursiveType(t) => {
358                    Diagnostic::error(unification_point, format!("Inferred a recursive type {t}"))
359                        .primary_label(format!("Type {t} inferred here"))
360                        .help("The recursion happens at the type marked \"*\"")
361                }
362            }
363        })
364    }
365}
366
367fn add_known_type_context(
368    diag: Diagnostic,
369    unification_point: impl Into<FullSpan> + Clone,
370    failing: &TypeVarID,
371    meta: bool,
372    type_state: &TypeState,
373) -> Diagnostic {
374    match failing.resolve(type_state) {
375        TypeVar::Known(k, _, _) => {
376            if FullSpan::from(k) != unification_point.clone().into() {
377                diag.secondary_label(
378                    k,
379                    format!(
380                        "Type {} inferred here",
381                        failing.display_with_meta(meta, type_state)
382                    ),
383                )
384            } else {
385                diag
386            }
387        }
388        TypeVar::Unknown(k, _, _, _) => {
389            if FullSpan::from(k) != unification_point.clone().into() {
390                diag.secondary_label(
391                    k,
392                    format!(
393                        "Type {} inferred here",
394                        failing.display_with_meta(meta, type_state)
395                    ),
396                )
397            } else {
398                diag
399            }
400        }
401    }
402}
403
404#[derive(Debug, PartialEq, Clone)]
405pub struct TypeMismatch {
406    /// Expected type
407    pub e: UnificationTrace,
408    /// Got type
409    pub g: UnificationTrace,
410}
411impl TypeMismatch {
412    pub fn is_meta_error(&self, type_state: &TypeState) -> bool {
413        matches!(
414            self.e.failing.resolve(type_state),
415            TypeVar::Unknown(_, _, _, _)
416        ) || matches!(
417            self.g.failing.resolve(type_state),
418            TypeVar::Unknown(_, _, _, _)
419        )
420    }
421
422    pub fn display_e_g(&self, type_state: &TypeState) -> (String, String) {
423        let is_meta = self.is_meta_error(type_state);
424        (
425            self.e.display_with_meta(is_meta, type_state),
426            self.g.display_with_meta(is_meta, type_state),
427        )
428    }
429}
430
431#[derive(Debug, PartialEq, Clone)]
432pub enum UnificationError {
433    Normal(TypeMismatch),
434    MetaMismatch(TypeMismatch),
435    RecursiveType(String),
436    Specific(spade_diagnostics::Diagnostic),
437    UnsatisfiedTraits {
438        var: TypeVarID,
439        traits: Vec<Loc<TraitReq>>,
440        target_loc: Loc<()>,
441    },
442    FromConstraints {
443        expected: UnificationTrace,
444        got: UnificationTrace,
445        source: ConstraintSource,
446        loc: Loc<()>,
447        is_meta_error: bool,
448    },
449}
450
451pub type Result<T> = std::result::Result<T, Diagnostic>;
452
453pub fn error_pattern_type_mismatch(
454    reason: Loc<()>,
455    type_state: &TypeState,
456) -> impl Fn(Diagnostic, TypeMismatch) -> Diagnostic + '_ {
457    move |diag, tm| {
458        let (expected, got) = tm.display_e_g(type_state);
459        diag.message(format!(
460            "Pattern type mismatch. Expected {expected} got {got}"
461        ))
462        .primary_label(format!("expected {expected}, got {got}"))
463        .secondary_label(reason, format!("because this has type {expected}"))
464    }
465}