spade_typeinference/
error.rs

1use itertools::Itertools;
2
3use spade_common::location_info::{FullSpan, Loc, WithLocation};
4use spade_diagnostics::Diagnostic;
5
6use crate::constraints::ConstraintSource;
7
8use super::equation::{TraitReq, TypeVar};
9
10/// A trace of a unification error. The `failing` field indicates which exact type failed to unify,
11/// while the `inside` is the "top level" type which failed to unify if it's not the same as
12/// failing.
13///
14/// For example, if unifying `int<7>` with `int<8>`, this would be `failing: 8, inside: int<8>`
15/// while if unifying `int<7>` with `bool`, inside would be `None`
16#[derive(Debug, PartialEq, Clone)]
17pub struct UnificationTrace {
18    pub failing: TypeVar,
19    pub inside: Option<TypeVar>,
20}
21impl WithLocation for UnificationTrace {}
22impl std::fmt::Display for UnificationTrace {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        write!(f, "{}", self.outer())
25    }
26}
27
28impl UnificationTrace {
29    pub fn new(failing: TypeVar) -> Self {
30        Self {
31            failing,
32            inside: None,
33        }
34    }
35
36    pub fn outer(&self) -> &TypeVar {
37        self.inside.as_ref().unwrap_or(&self.failing)
38    }
39
40    pub fn display_with_meta(&self, meta: bool) -> String {
41        self.inside
42            .as_ref()
43            .unwrap_or(&self.failing)
44            .display_with_meta(meta)
45    }
46}
47pub trait UnificationErrorExt<T>: Sized {
48    fn add_context(
49        self,
50        got: TypeVar,
51        expected: TypeVar,
52    ) -> std::result::Result<T, UnificationError>;
53
54    /// Creates a diagnostic with a generic type mismatch error
55    fn into_default_diagnostic(
56        self,
57        unification_point: impl Into<FullSpan> + Clone,
58    ) -> std::result::Result<T, Diagnostic> {
59        self.into_diagnostic(unification_point, |d, _| d)
60    }
61
62    /// Creates a diagnostic with a generic type mismatch error
63    fn into_diagnostic_or_default<F>(
64        self,
65        unification_point: impl Into<FullSpan> + Clone,
66        message: Option<F>,
67    ) -> std::result::Result<T, Diagnostic>
68    where
69        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
70    {
71        if let Some(message) = message {
72            self.into_diagnostic(unification_point, message)
73        } else {
74            self.into_diagnostic(unification_point, |d, _| d)
75        }
76    }
77
78    /// Creates a diagnostic from the unification error that will be emitted at the unification
79    /// point, unless the unification error was caused by constraints, at which point
80    /// the source of those constraints will be the location of the error.
81    /// If trait constraints were not met, a default message will be provided at the unification
82    /// point
83    fn into_diagnostic<F>(
84        self,
85        unification_point: impl Into<FullSpan> + Clone,
86        message: F,
87    ) -> std::result::Result<T, Diagnostic>
88    where
89        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
90    {
91        self.into_diagnostic_impl(unification_point, false, message)
92    }
93
94    fn into_diagnostic_no_expected_source<F>(
95        self,
96        unification_point: impl Into<FullSpan> + Clone,
97        message: F,
98    ) -> std::result::Result<T, Diagnostic>
99    where
100        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
101    {
102        self.into_diagnostic_impl(unification_point, true, message)
103    }
104
105    fn into_diagnostic_impl<F>(
106        self,
107        unification_point: impl Into<FullSpan> + Clone,
108        omit_expected_source: bool,
109        message: F,
110    ) -> std::result::Result<T, Diagnostic>
111    where
112        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic;
113}
114impl<T> UnificationErrorExt<T> for std::result::Result<T, UnificationError> {
115    fn add_context(
116        self,
117        got: TypeVar,
118        expected: TypeVar,
119    ) -> std::result::Result<T, UnificationError> {
120        match self {
121            Ok(val) => Ok(val),
122            Err(UnificationError::Normal(TypeMismatch {
123                e: mut old_e,
124                g: mut old_g,
125            })) => {
126                old_e.inside.replace(expected);
127                old_g.inside.replace(got);
128                Err(UnificationError::Normal(TypeMismatch {
129                    e: old_e,
130                    g: old_g,
131                }))
132            }
133            Err(UnificationError::MetaMismatch(TypeMismatch {
134                e: mut old_e,
135                g: mut old_g,
136            })) => {
137                old_e.inside.replace(expected);
138                old_g.inside.replace(got);
139                Err(UnificationError::MetaMismatch(TypeMismatch {
140                    e: old_e,
141                    g: old_g,
142                }))
143            }
144            e @ Err(UnificationError::UnsatisfiedTraits { .. }) => e,
145            e @ Err(
146                UnificationError::FromConstraints { .. } | UnificationError::Specific { .. },
147            ) => e,
148        }
149    }
150
151    fn into_diagnostic_impl<F>(
152        self,
153        unification_point: impl Into<FullSpan> + Clone,
154        omit_expected_source: bool,
155        message: F,
156    ) -> std::result::Result<T, Diagnostic>
157    where
158        F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
159    {
160        self.map_err(|err| {
161            let display_meta = match &err {
162                UnificationError::Normal { .. } => false,
163                UnificationError::MetaMismatch { .. } => true,
164                _ => false,
165            };
166            match err {
167                UnificationError::Normal(TypeMismatch { e, g })
168                | UnificationError::MetaMismatch(TypeMismatch { e, g }) => {
169                    let e_disp = e.display_with_meta(display_meta);
170                    let g_disp = g.display_with_meta(display_meta);
171                    let msg = format!("Expected type {e_disp}, got {g_disp}");
172                    let diag = Diagnostic::error(unification_point.clone(), msg)
173                        .primary_label(format!("Expected {e_disp}"));
174                    let diag = message(
175                        diag,
176                        TypeMismatch {
177                            e: e.clone(),
178                            g: g.clone(),
179                        },
180                    );
181
182                    let diag = if !omit_expected_source {
183                        add_known_type_context(
184                            diag,
185                            unification_point.clone(),
186                            &e.failing,
187                            display_meta,
188                        )
189                    } else {
190                        diag
191                    };
192
193                    let diag =
194                        add_known_type_context(diag, unification_point, &g.failing, display_meta);
195                    diag.type_error(
196                        format!("{}", e.failing.display_with_meta(display_meta)),
197                        e.inside.map(|o| o.display_with_meta(display_meta)),
198                        format!("{}", g.failing.display_with_meta(display_meta)),
199                        g.inside.map(|o| o.display_with_meta(display_meta)),
200                    )
201                }
202                UnificationError::UnsatisfiedTraits {
203                    var,
204                    traits,
205                    target_loc: _,
206                } => {
207                    let trait_bound_loc = ().at_loc(&traits[0]);
208                    let impls_str = if traits.len() >= 2 {
209                        format!(
210                            "{} and {}",
211                            traits[0..traits.len() - 1]
212                                .iter()
213                                .map(|i| i.inner.display_with_meta(display_meta))
214                                .join(", "),
215                            traits[traits.len() - 1]
216                        )
217                    } else {
218                        format!("{}", traits[0].display_with_meta(display_meta))
219                    };
220                    let short_msg = format!("{var} does not implement {impls_str}");
221                    Diagnostic::error(
222                        unification_point,
223                        format!("Trait bound not satisfied. {short_msg}"),
224                    )
225                    .primary_label(short_msg)
226                    .secondary_label(
227                        trait_bound_loc,
228                        "Required because of the trait bound specified here",
229                    )
230                }
231                UnificationError::FromConstraints {
232                    expected,
233                    got,
234                    source,
235                    loc,
236                    is_meta_error,
237                } => {
238                    let diag = Diagnostic::error(
239                        loc,
240                        format!(
241                            "Expected type {}, got {}",
242                            expected.display_with_meta(is_meta_error),
243                            got.display_with_meta(is_meta_error)
244                        ),
245                    )
246                    .primary_label(format!(
247                        "Expected {}, got {}",
248                        expected.display_with_meta(is_meta_error),
249                        got.display_with_meta(is_meta_error)
250                    ));
251
252                    let diag = diag.type_error(
253                        format!("{}", expected.failing.display_with_meta(is_meta_error)),
254                        expected
255                            .inside
256                            .as_ref()
257                            .map(|o| o.display_with_meta(is_meta_error)),
258                        format!("{}", got.failing.display_with_meta(is_meta_error)),
259                        got.inside
260                            .as_ref()
261                            .map(|o| o.display_with_meta(is_meta_error)),
262                    );
263
264                    match source {
265                        ConstraintSource::AdditionOutput => diag.note(
266                            "Addition creates one more output bit than the input to avoid overflow"
267                                .to_string(),
268                        ),
269                        ConstraintSource::MultOutput => diag.note(
270                            "The size of a multiplication is the sum of the operand sizes"
271                                .to_string(),
272                        ),
273                        ConstraintSource::ArrayIndexing => {
274                            // NOTE: This error message could probably be improved
275                            diag.note(
276                                "because the value is used as an index to an array".to_string(),
277                            )
278                        }
279                        ConstraintSource::MemoryIndexing => {
280                            // NOTE: This error message could probably be improved
281                            diag.note(
282                                "because the value is used as an index to a memory".to_string(),
283                            )
284                        }
285                        ConstraintSource::Concatenation => diag.note(
286                            "The size of a concatenation is the sum of the operand sizes"
287                                .to_string(),
288                        ),
289                        ConstraintSource::ArraySize => {
290                            diag.note("The number of array elements must  match")
291                        }
292                        ConstraintSource::RangeIndex => diag,
293                        ConstraintSource::RangeIndexOutputSize => diag.note(
294                            "The output of a range index is an array inferred from the indices",
295                        ),
296                        ConstraintSource::TypeLevelIf | ConstraintSource::Where => diag,
297                        ConstraintSource::PipelineRegOffset { .. } => diag,
298                        ConstraintSource::PipelineRegCount { reg, total } => Diagnostic::error(
299                            total,
300                            format!("Expected {expected} in this pipeline."),
301                        )
302                        .primary_label(format!("Expected {expected} stages"))
303                        .secondary_label(reg, format!("This final register is number {got}")),
304                        ConstraintSource::PipelineAvailDepth => diag,
305                    }
306                }
307                UnificationError::Specific(e) => e,
308            }
309        })
310    }
311}
312
313fn add_known_type_context(
314    diag: Diagnostic,
315    unification_point: impl Into<FullSpan> + Clone,
316    failing: &TypeVar,
317    meta: bool,
318) -> Diagnostic {
319    match failing {
320        TypeVar::Known(k, _, _) => {
321            if FullSpan::from(k) != unification_point.clone().into() {
322                diag.secondary_label(
323                    k,
324                    format!("Type {} inferred here", failing.display_with_meta(meta)),
325                )
326            } else {
327                diag
328            }
329        }
330        TypeVar::Unknown(k, _, _, _) => {
331            if FullSpan::from(k) != unification_point.clone().into() {
332                diag.secondary_label(
333                    k,
334                    format!("Type {} inferred here", failing.display_with_meta(meta)),
335                )
336            } else {
337                diag
338            }
339        }
340    }
341}
342
343#[derive(Debug, PartialEq, Clone)]
344pub struct TypeMismatch {
345    /// Expected type
346    pub e: UnificationTrace,
347    /// Got type
348    pub g: UnificationTrace,
349}
350impl TypeMismatch {
351    pub fn is_meta_error(&self) -> bool {
352        matches!(self.e.failing, TypeVar::Unknown(_, _, _, _))
353            || matches!(self.g.failing, TypeVar::Unknown(_, _, _, _))
354    }
355
356    pub fn display_e_g(&self) -> (String, String) {
357        let is_meta = self.is_meta_error();
358        (
359            self.e.display_with_meta(is_meta),
360            self.g.display_with_meta(is_meta),
361        )
362    }
363}
364impl std::fmt::Display for TypeMismatch {
365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        write!(f, "expected {}, got {}", self.e, self.g)
367    }
368}
369
370#[derive(Debug, PartialEq, Clone)]
371pub enum UnificationError {
372    Normal(TypeMismatch),
373    MetaMismatch(TypeMismatch),
374    Specific(spade_diagnostics::Diagnostic),
375    UnsatisfiedTraits {
376        var: TypeVar,
377        traits: Vec<Loc<TraitReq>>,
378        target_loc: Loc<()>,
379    },
380    FromConstraints {
381        expected: UnificationTrace,
382        got: UnificationTrace,
383        source: ConstraintSource,
384        loc: Loc<()>,
385        is_meta_error: bool,
386    },
387}
388
389pub type Result<T> = std::result::Result<T, Diagnostic>;
390
391pub fn error_pattern_type_mismatch(
392    reason: Loc<()>,
393) -> impl Fn(Diagnostic, TypeMismatch) -> Diagnostic {
394    move |diag,
395          TypeMismatch {
396              e: expected,
397              g: got,
398          }| {
399        diag.message(format!(
400            "Pattern type mismatch. Expected {expected} got {got}"
401        ))
402        .primary_label(format!("expected {expected}, got {got}"))
403        .secondary_label(reason, format!("because this has type {expected}"))
404    }
405}