1use itertools::Itertools;
2
3use spade_common::location_info::{FullSpan, Loc, WithLocation};
4use spade_diagnostics::Diagnostic;
5
6use crate::{constraints::ConstraintSource, equation::TypeVarID, traits::TraitReq, TypeState};
7
8use super::equation::TypeVar;
9
10#[derive(Debug, PartialEq, Clone)]
17pub struct UnificationTrace {
18 pub failing: TypeVarID,
19 pub inside: Option<TypeVarID>,
20}
21
22impl UnificationTrace {
23 pub fn new(failing: TypeVarID) -> Self {
24 Self {
25 failing,
26 inside: None,
27 }
28 }
29
30 pub fn outer(&self) -> TypeVarID {
31 *self.inside.as_ref().unwrap_or(&self.failing)
32 }
33
34 pub fn display_with_meta(&self, meta: bool, type_state: &TypeState) -> String {
35 self.inside
36 .as_ref()
37 .unwrap_or(&self.failing)
38 .display_with_meta(meta, type_state)
39 }
40
41 pub fn display(&self, type_state: &TypeState) -> String {
42 self.outer().display_with_meta(false, type_state)
43 }
44}
45
46pub trait UnificationErrorExt<T>: Sized {
47 fn add_context(
48 self,
49 got: TypeVarID,
50 expected: TypeVarID,
51 ) -> std::result::Result<T, UnificationError>;
52
53 fn into_default_diagnostic(
55 self,
56 unification_point: impl Into<FullSpan> + Clone,
57 type_state: &TypeState,
58 ) -> std::result::Result<T, Diagnostic> {
59 self.into_diagnostic(unification_point, |d, _| d, type_state)
60 }
61
62 fn into_diagnostic_or_default<F>(
64 self,
65 unification_point: impl Into<FullSpan> + Clone,
66 message: Option<F>,
67 type_state: &TypeState,
68 ) -> std::result::Result<T, Diagnostic>
69 where
70 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
71 {
72 if let Some(message) = message {
73 self.into_diagnostic(unification_point, message, type_state)
74 } else {
75 self.into_diagnostic(unification_point, |d, _| d, type_state)
76 }
77 }
78
79 fn into_diagnostic<F>(
85 self,
86 unification_point: impl Into<FullSpan> + Clone,
87 message: F,
88 type_state: &TypeState,
89 ) -> std::result::Result<T, Diagnostic>
90 where
91 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
92 {
93 self.into_diagnostic_impl(unification_point, false, message, type_state)
94 }
95
96 fn into_diagnostic_no_expected_source<F>(
97 self,
98 unification_point: impl Into<FullSpan> + Clone,
99 message: F,
100 type_state: &TypeState,
101 ) -> std::result::Result<T, Diagnostic>
102 where
103 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
104 {
105 self.into_diagnostic_impl(unification_point, true, message, type_state)
106 }
107
108 fn into_diagnostic_impl<F>(
109 self,
110 unification_point: impl Into<FullSpan> + Clone,
111 omit_expected_source: bool,
112 message: F,
113 type_state: &TypeState,
114 ) -> std::result::Result<T, Diagnostic>
115 where
116 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic;
117}
118impl<T> UnificationErrorExt<T> for std::result::Result<T, UnificationError> {
119 fn add_context(
120 self,
121 got: TypeVarID,
122 expected: TypeVarID,
123 ) -> std::result::Result<T, UnificationError> {
124 match self {
125 Ok(val) => Ok(val),
126 Err(UnificationError::Normal(TypeMismatch {
127 e: mut old_e,
128 g: mut old_g,
129 })) => {
130 old_e.inside.replace(expected);
131 old_g.inside.replace(got);
132 Err(UnificationError::Normal(TypeMismatch {
133 e: old_e,
134 g: old_g,
135 }))
136 }
137 Err(UnificationError::MetaMismatch(TypeMismatch {
138 e: mut old_e,
139 g: mut old_g,
140 })) => {
141 old_e.inside.replace(expected);
142 old_g.inside.replace(got);
143 Err(UnificationError::MetaMismatch(TypeMismatch {
144 e: old_e,
145 g: old_g,
146 }))
147 }
148 e @ Err(UnificationError::RecursiveType { .. }) => e,
149 e @ Err(UnificationError::UnsatisfiedTraits { .. }) => e,
150 e @ Err(
151 UnificationError::FromConstraints { .. } | UnificationError::Specific { .. },
152 ) => e,
153 }
154 }
155
156 fn into_diagnostic_impl<F>(
157 self,
158 unification_point: impl Into<FullSpan> + Clone,
159 omit_expected_source: bool,
160 message: F,
161 type_state: &TypeState,
162 ) -> std::result::Result<T, Diagnostic>
163 where
164 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
165 {
166 self.map_err(|err| {
167 let display_meta = match &err {
168 UnificationError::Normal { .. } => false,
169 UnificationError::MetaMismatch { .. } => true,
170 _ => false,
171 };
172 match err {
173 UnificationError::Normal(TypeMismatch { e, g })
174 | UnificationError::MetaMismatch(TypeMismatch { e, g }) => {
175 let e_disp = e.display_with_meta(display_meta, type_state);
176 let g_disp = g.display_with_meta(display_meta, type_state);
177 let msg = format!("Expected type {e_disp}, got {g_disp}");
178 let diag = Diagnostic::error(unification_point.clone(), msg)
179 .primary_label(format!("Expected {e_disp}"));
180 let diag = message(
181 diag,
182 TypeMismatch {
183 e: e.clone(),
184 g: g.clone(),
185 },
186 );
187
188 let diag = if !omit_expected_source {
189 add_known_type_context(
190 diag,
191 unification_point.clone(),
192 &e.failing,
193 display_meta,
194 type_state,
195 )
196 } else {
197 diag
198 };
199
200 let diag = add_known_type_context(
201 diag,
202 unification_point,
203 &g.failing,
204 display_meta,
205 type_state,
206 );
207 diag.type_error(
208 format!("{}", e.failing.display_with_meta(display_meta, type_state)),
209 e.inside
210 .map(|o| o.display_with_meta(display_meta, type_state)),
211 format!("{}", g.failing.display_with_meta(display_meta, type_state)),
212 g.inside
213 .map(|o| o.display_with_meta(display_meta, type_state)),
214 )
215 }
216 UnificationError::UnsatisfiedTraits {
217 var,
218 traits,
219 target_loc: _,
220 } => {
221 let trait_bound_loc = ().at_loc(&traits[0]);
222 let impls_str = if traits.len() >= 2 {
223 format!(
224 "{} and {}",
225 traits[0..traits.len() - 1]
226 .iter()
227 .map(|i| i.inner.display_with_meta(display_meta, type_state))
228 .join(", "),
229 traits[traits.len() - 1].display_with_meta(display_meta, type_state)
230 )
231 } else {
232 format!("{}", traits[0].display_with_meta(display_meta, type_state))
233 };
234 let short_msg = format!(
235 "{var} does not implement {impls_str}",
236 var = var.display_with_meta(display_meta, type_state)
237 );
238 Diagnostic::error(
239 unification_point,
240 format!("Trait bound not satisfied. {short_msg}"),
241 )
242 .primary_label(short_msg)
243 .secondary_label(
244 trait_bound_loc,
245 "Required because of the trait bound specified here",
246 )
247 }
248 UnificationError::FromConstraints {
249 expected,
250 got,
251 source,
252 loc,
253 is_meta_error,
254 } => {
255 let diag = Diagnostic::error(
256 loc,
257 format!(
258 "Expected type {}, got {}",
259 expected.display_with_meta(is_meta_error, type_state),
260 got.display_with_meta(is_meta_error, type_state)
261 ),
262 )
263 .primary_label(format!(
264 "Expected {}, got {}",
265 expected.display_with_meta(is_meta_error, type_state),
266 got.display_with_meta(is_meta_error, type_state)
267 ));
268
269 let diag = diag.type_error(
270 format!(
271 "{}",
272 expected
273 .failing
274 .display_with_meta(is_meta_error, type_state)
275 ),
276 expected
277 .inside
278 .as_ref()
279 .map(|o| o.display_with_meta(is_meta_error, type_state)),
280 format!(
281 "{}",
282 got.failing.display_with_meta(is_meta_error, type_state)
283 ),
284 got.inside
285 .as_ref()
286 .map(|o| o.display_with_meta(is_meta_error, type_state)),
287 );
288
289 let diag = match source {
290 ConstraintSource::AdditionOutput => diag.note(
291 "Addition creates one more output bit than the input to avoid overflow"
292 .to_string(),
293 ),
294 ConstraintSource::MultOutput => diag.note(
295 "The size of a multiplication is the sum of the operand sizes"
296 .to_string(),
297 ),
298 ConstraintSource::ArrayIndexing => {
299 diag.note(
301 "because the value is used as an index to an array".to_string(),
302 )
303 }
304 ConstraintSource::MemoryIndexing => {
305 diag.note(
307 "because the value is used as an index to a memory".to_string(),
308 )
309 }
310 ConstraintSource::Concatenation => diag.note(
311 "The size of a concatenation is the sum of the operand sizes"
312 .to_string(),
313 ),
314 ConstraintSource::ArraySize => {
315 diag.note("The number of array elements must match")
316 }
317 ConstraintSource::RangeIndex => diag,
318 ConstraintSource::RangeIndexOutputSize => diag.note(
319 "The output of a range index is an array inferred from the indices",
320 ),
321 ConstraintSource::TypeLevelIf | ConstraintSource::Where => diag,
322 ConstraintSource::PipelineRegOffset { .. } => diag,
323 ConstraintSource::PipelineRegCount { reg, total } => Diagnostic::error(
324 total,
325 format!(
326 "Expected {expected} in this pipeline.",
327 expected = expected.display(type_state)
328 ),
329 )
330 .primary_label(format!(
331 "Expected {expected} stages",
332 expected = expected.display(type_state)
333 ))
334 .secondary_label(
335 reg,
336 format!(
337 "This final register is number {expected}",
338 expected = got.display(type_state)
339 ),
340 ),
341 ConstraintSource::PipelineAvailDepth => diag,
342 };
343
344 if unification_point.clone().into().0 != loc.span {
345 diag.secondary_label(
346 unification_point,
347 "The error occurred when unifying types here",
348 )
349 } else {
350 diag
351 }
352 }
353 UnificationError::Specific(e) => e.secondary_label(
354 unification_point,
355 "The error occurred when unifying types here",
356 ),
357 UnificationError::RecursiveType(t) => {
358 Diagnostic::error(unification_point, format!("Inferred a recursive type {t}"))
359 .primary_label(format!("Type {t} inferred here"))
360 .help("The recursion happens at the type marked \"*\"")
361 }
362 }
363 })
364 }
365}
366
367fn add_known_type_context(
368 diag: Diagnostic,
369 unification_point: impl Into<FullSpan> + Clone,
370 failing: &TypeVarID,
371 meta: bool,
372 type_state: &TypeState,
373) -> Diagnostic {
374 match failing.resolve(type_state) {
375 TypeVar::Known(k, _, _) => {
376 if FullSpan::from(k) != unification_point.clone().into() {
377 diag.secondary_label(
378 k,
379 format!(
380 "Type {} inferred here",
381 failing.display_with_meta(meta, type_state)
382 ),
383 )
384 } else {
385 diag
386 }
387 }
388 TypeVar::Unknown(k, _, _, _) => {
389 if FullSpan::from(k) != unification_point.clone().into() {
390 diag.secondary_label(
391 k,
392 format!(
393 "Type {} inferred here",
394 failing.display_with_meta(meta, type_state)
395 ),
396 )
397 } else {
398 diag
399 }
400 }
401 }
402}
403
404#[derive(Debug, PartialEq, Clone)]
405pub struct TypeMismatch {
406 pub e: UnificationTrace,
408 pub g: UnificationTrace,
410}
411impl TypeMismatch {
412 pub fn is_meta_error(&self, type_state: &TypeState) -> bool {
413 matches!(
414 self.e.failing.resolve(type_state),
415 TypeVar::Unknown(_, _, _, _)
416 ) || matches!(
417 self.g.failing.resolve(type_state),
418 TypeVar::Unknown(_, _, _, _)
419 )
420 }
421
422 pub fn display_e_g(&self, type_state: &TypeState) -> (String, String) {
423 let is_meta = self.is_meta_error(type_state);
424 (
425 self.e.display_with_meta(is_meta, type_state),
426 self.g.display_with_meta(is_meta, type_state),
427 )
428 }
429}
430
431#[derive(Debug, PartialEq, Clone)]
432pub enum UnificationError {
433 Normal(TypeMismatch),
434 MetaMismatch(TypeMismatch),
435 RecursiveType(String),
436 Specific(spade_diagnostics::Diagnostic),
437 UnsatisfiedTraits {
438 var: TypeVarID,
439 traits: Vec<Loc<TraitReq>>,
440 target_loc: Loc<()>,
441 },
442 FromConstraints {
443 expected: UnificationTrace,
444 got: UnificationTrace,
445 source: ConstraintSource,
446 loc: Loc<()>,
447 is_meta_error: bool,
448 },
449}
450
451pub type Result<T> = std::result::Result<T, Diagnostic>;
452
453pub fn error_pattern_type_mismatch(
454 reason: Loc<()>,
455 type_state: &TypeState,
456) -> impl Fn(Diagnostic, TypeMismatch) -> Diagnostic + '_ {
457 move |diag, tm| {
458 let (expected, got) = tm.display_e_g(type_state);
459 diag.message(format!(
460 "Pattern type mismatch. Expected {expected} got {got}"
461 ))
462 .primary_label(format!("expected {expected}, got {got}"))
463 .secondary_label(reason, format!("because this has type {expected}"))
464 }
465}