1use itertools::Itertools;
2
3use spade_common::location_info::{FullSpan, Loc, WithLocation};
4use spade_diagnostics::Diagnostic;
5
6use crate::{constraints::ConstraintSource, equation::TypeVarID, traits::TraitReq, TypeState};
7
8use super::equation::TypeVar;
9
/// One side of a failed unification.
///
/// `failing` is the type variable at which unification actually failed;
/// `inside` optionally records the enclosing type in which that failure
/// occurred (set by `UnificationErrorExt::add_context`).
#[derive(Debug, PartialEq, Clone)]
pub struct UnificationTrace {
    /// The type variable at which unification failed.
    pub failing: TypeVarID,
    /// The enclosing type containing `failing`, if context has been recorded.
    pub inside: Option<TypeVarID>,
}
// Opts the trace into the `WithLocation` helpers so it can be given a source location.
impl WithLocation for UnificationTrace {}
22
23impl UnificationTrace {
24 pub fn new(failing: TypeVarID) -> Self {
25 Self {
26 failing,
27 inside: None,
28 }
29 }
30
31 pub fn outer(&self) -> TypeVarID {
32 *self.inside.as_ref().unwrap_or(&self.failing)
33 }
34
35 pub fn display_with_meta(&self, meta: bool, type_state: &TypeState) -> String {
36 self.inside
37 .as_ref()
38 .unwrap_or(&self.failing)
39 .display_with_meta(meta, type_state)
40 }
41
42 pub fn display(&self, type_state: &TypeState) -> String {
43 self.outer().display_with_meta(false, type_state)
44 }
45}
46
47pub trait UnificationErrorExt<T>: Sized {
48 fn add_context(
49 self,
50 got: TypeVarID,
51 expected: TypeVarID,
52 ) -> std::result::Result<T, UnificationError>;
53
54 fn into_default_diagnostic(
56 self,
57 unification_point: impl Into<FullSpan> + Clone,
58 type_state: &TypeState,
59 ) -> std::result::Result<T, Diagnostic> {
60 self.into_diagnostic(unification_point, |d, _| d, type_state)
61 }
62
63 fn into_diagnostic_or_default<F>(
65 self,
66 unification_point: impl Into<FullSpan> + Clone,
67 message: Option<F>,
68 type_state: &TypeState,
69 ) -> std::result::Result<T, Diagnostic>
70 where
71 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
72 {
73 if let Some(message) = message {
74 self.into_diagnostic(unification_point, message, type_state)
75 } else {
76 self.into_diagnostic(unification_point, |d, _| d, type_state)
77 }
78 }
79
80 fn into_diagnostic<F>(
86 self,
87 unification_point: impl Into<FullSpan> + Clone,
88 message: F,
89 type_state: &TypeState,
90 ) -> std::result::Result<T, Diagnostic>
91 where
92 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
93 {
94 self.into_diagnostic_impl(unification_point, false, message, type_state)
95 }
96
97 fn into_diagnostic_no_expected_source<F>(
98 self,
99 unification_point: impl Into<FullSpan> + Clone,
100 message: F,
101 type_state: &TypeState,
102 ) -> std::result::Result<T, Diagnostic>
103 where
104 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
105 {
106 self.into_diagnostic_impl(unification_point, true, message, type_state)
107 }
108
109 fn into_diagnostic_impl<F>(
110 self,
111 unification_point: impl Into<FullSpan> + Clone,
112 omit_expected_source: bool,
113 message: F,
114 type_state: &TypeState,
115 ) -> std::result::Result<T, Diagnostic>
116 where
117 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic;
118}
119impl<T> UnificationErrorExt<T> for std::result::Result<T, UnificationError> {
120 fn add_context(
121 self,
122 got: TypeVarID,
123 expected: TypeVarID,
124 ) -> std::result::Result<T, UnificationError> {
125 match self {
126 Ok(val) => Ok(val),
127 Err(UnificationError::Normal(TypeMismatch {
128 e: mut old_e,
129 g: mut old_g,
130 })) => {
131 old_e.inside.replace(expected);
132 old_g.inside.replace(got);
133 Err(UnificationError::Normal(TypeMismatch {
134 e: old_e,
135 g: old_g,
136 }))
137 }
138 Err(UnificationError::MetaMismatch(TypeMismatch {
139 e: mut old_e,
140 g: mut old_g,
141 })) => {
142 old_e.inside.replace(expected);
143 old_g.inside.replace(got);
144 Err(UnificationError::MetaMismatch(TypeMismatch {
145 e: old_e,
146 g: old_g,
147 }))
148 }
149 e @ Err(UnificationError::RecursiveType { .. }) => e,
150 e @ Err(UnificationError::UnsatisfiedTraits { .. }) => e,
151 e @ Err(
152 UnificationError::FromConstraints { .. } | UnificationError::Specific { .. },
153 ) => e,
154 }
155 }
156
157 fn into_diagnostic_impl<F>(
158 self,
159 unification_point: impl Into<FullSpan> + Clone,
160 omit_expected_source: bool,
161 message: F,
162 type_state: &TypeState,
163 ) -> std::result::Result<T, Diagnostic>
164 where
165 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
166 {
167 self.map_err(|err| {
168 let display_meta = match &err {
169 UnificationError::Normal { .. } => false,
170 UnificationError::MetaMismatch { .. } => true,
171 _ => false,
172 };
173 match err {
174 UnificationError::Normal(TypeMismatch { e, g })
175 | UnificationError::MetaMismatch(TypeMismatch { e, g }) => {
176 let e_disp = e.display_with_meta(display_meta, type_state);
177 let g_disp = g.display_with_meta(display_meta, type_state);
178 let msg = format!("Expected type {e_disp}, got {g_disp}");
179 let diag = Diagnostic::error(unification_point.clone(), msg)
180 .primary_label(format!("Expected {e_disp}"));
181 let diag = message(
182 diag,
183 TypeMismatch {
184 e: e.clone(),
185 g: g.clone(),
186 },
187 );
188
189 let diag = if !omit_expected_source {
190 add_known_type_context(
191 diag,
192 unification_point.clone(),
193 &e.failing,
194 display_meta,
195 type_state,
196 )
197 } else {
198 diag
199 };
200
201 let diag = add_known_type_context(
202 diag,
203 unification_point,
204 &g.failing,
205 display_meta,
206 type_state,
207 );
208 diag.type_error(
209 format!("{}", e.failing.display_with_meta(display_meta, type_state)),
210 e.inside
211 .map(|o| o.display_with_meta(display_meta, type_state)),
212 format!("{}", g.failing.display_with_meta(display_meta, type_state)),
213 g.inside
214 .map(|o| o.display_with_meta(display_meta, type_state)),
215 )
216 }
217 UnificationError::UnsatisfiedTraits {
218 var,
219 traits,
220 target_loc: _,
221 } => {
222 let trait_bound_loc = ().at_loc(&traits[0]);
223 let impls_str = if traits.len() >= 2 {
224 format!(
225 "{} and {}",
226 traits[0..traits.len() - 1]
227 .iter()
228 .map(|i| i.inner.display_with_meta(display_meta, type_state))
229 .join(", "),
230 traits[traits.len() - 1].display_with_meta(display_meta, type_state)
231 )
232 } else {
233 format!("{}", traits[0].display_with_meta(display_meta, type_state))
234 };
235 let short_msg = format!(
236 "{var} does not implement {impls_str}",
237 var = var.display_with_meta(display_meta, type_state)
238 );
239 Diagnostic::error(
240 unification_point,
241 format!("Trait bound not satisfied. {short_msg}"),
242 )
243 .primary_label(short_msg)
244 .secondary_label(
245 trait_bound_loc,
246 "Required because of the trait bound specified here",
247 )
248 }
249 UnificationError::FromConstraints {
250 expected,
251 got,
252 source,
253 loc,
254 is_meta_error,
255 } => {
256 let diag = Diagnostic::error(
257 loc,
258 format!(
259 "Expected type {}, got {}",
260 expected.display_with_meta(is_meta_error, type_state),
261 got.display_with_meta(is_meta_error, type_state)
262 ),
263 )
264 .primary_label(format!(
265 "Expected {}, got {}",
266 expected.display_with_meta(is_meta_error, type_state),
267 got.display_with_meta(is_meta_error, type_state)
268 ));
269
270 let diag = diag.type_error(
271 format!(
272 "{}",
273 expected
274 .failing
275 .display_with_meta(is_meta_error, type_state)
276 ),
277 expected
278 .inside
279 .as_ref()
280 .map(|o| o.display_with_meta(is_meta_error, type_state)),
281 format!(
282 "{}",
283 got.failing.display_with_meta(is_meta_error, type_state)
284 ),
285 got.inside
286 .as_ref()
287 .map(|o| o.display_with_meta(is_meta_error, type_state)),
288 );
289
290 match source {
291 ConstraintSource::AdditionOutput => diag.note(
292 "Addition creates one more output bit than the input to avoid overflow"
293 .to_string(),
294 ),
295 ConstraintSource::MultOutput => diag.note(
296 "The size of a multiplication is the sum of the operand sizes"
297 .to_string(),
298 ),
299 ConstraintSource::ArrayIndexing => {
300 diag.note(
302 "because the value is used as an index to an array".to_string(),
303 )
304 }
305 ConstraintSource::MemoryIndexing => {
306 diag.note(
308 "because the value is used as an index to a memory".to_string(),
309 )
310 }
311 ConstraintSource::Concatenation => diag.note(
312 "The size of a concatenation is the sum of the operand sizes"
313 .to_string(),
314 ),
315 ConstraintSource::ArraySize => {
316 diag.note("The number of array elements must match")
317 }
318 ConstraintSource::RangeIndex => diag,
319 ConstraintSource::RangeIndexOutputSize => diag.note(
320 "The output of a range index is an array inferred from the indices",
321 ),
322 ConstraintSource::TypeLevelIf | ConstraintSource::Where => diag,
323 ConstraintSource::PipelineRegOffset { .. } => diag,
324 ConstraintSource::PipelineRegCount { reg, total } => Diagnostic::error(
325 total,
326 format!(
327 "Expected {expected} in this pipeline.",
328 expected = expected.display(type_state)
329 ),
330 )
331 .primary_label(format!(
332 "Expected {expected} stages",
333 expected = expected.display(type_state)
334 ))
335 .secondary_label(
336 reg,
337 format!(
338 "This final register is number {expected}",
339 expected = got.display(type_state)
340 ),
341 ),
342 ConstraintSource::PipelineAvailDepth => diag,
343 }
344 }
345 UnificationError::Specific(e) => e.secondary_label(
346 unification_point,
347 "The error occurred when unifying types here",
348 ),
349 UnificationError::RecursiveType(t) => {
350 Diagnostic::error(unification_point, format!("Inferred a recursive type {t}"))
351 .primary_label(format!("Type {t} inferred here"))
352 .help("The recursion happens at the type marked \"*\"")
353 }
354 }
355 })
356 }
357}
358
359fn add_known_type_context(
360 diag: Diagnostic,
361 unification_point: impl Into<FullSpan> + Clone,
362 failing: &TypeVarID,
363 meta: bool,
364 type_state: &TypeState,
365) -> Diagnostic {
366 match failing.resolve(type_state) {
367 TypeVar::Known(k, _, _) => {
368 if FullSpan::from(k) != unification_point.clone().into() {
369 diag.secondary_label(
370 k,
371 format!(
372 "Type {} inferred here",
373 failing.display_with_meta(meta, type_state)
374 ),
375 )
376 } else {
377 diag
378 }
379 }
380 TypeVar::Unknown(k, _, _, _) => {
381 if FullSpan::from(k) != unification_point.clone().into() {
382 diag.secondary_label(
383 k,
384 format!(
385 "Type {} inferred here",
386 failing.display_with_meta(meta, type_state)
387 ),
388 )
389 } else {
390 diag
391 }
392 }
393 }
394}
395
/// The two sides of a failed unification, reported as
/// "Expected type {e}, got {g}".
#[derive(Debug, PartialEq, Clone)]
pub struct TypeMismatch {
    /// The expected side of the mismatch.
    pub e: UnificationTrace,
    /// The side that was actually found ("got").
    pub g: UnificationTrace,
}
403impl TypeMismatch {
404 pub fn is_meta_error(&self, type_state: &TypeState) -> bool {
405 matches!(
406 self.e.failing.resolve(type_state),
407 TypeVar::Unknown(_, _, _, _)
408 ) || matches!(
409 self.g.failing.resolve(type_state),
410 TypeVar::Unknown(_, _, _, _)
411 )
412 }
413
414 pub fn display_e_g(&self, type_state: &TypeState) -> (String, String) {
415 let is_meta = self.is_meta_error(type_state);
416 (
417 self.e.display_with_meta(is_meta, type_state),
418 self.g.display_with_meta(is_meta, type_state),
419 )
420 }
421}
422
/// The ways in which unifying two types can fail.
#[derive(Debug, PartialEq, Clone)]
pub enum UnificationError {
    /// The two types do not match. Rendered without meta information.
    Normal(TypeMismatch),
    /// The two types do not match. Rendered with meta information
    /// (`display_with_meta(true, ..)`).
    MetaMismatch(TypeMismatch),
    /// Unification would create a recursive type. The string is the already
    /// rendered type; the emitted help text says the recursion point is marked "*".
    RecursiveType(String),
    /// A fully formed diagnostic. Conversion only adds a label pointing at the
    /// unification point.
    Specific(spade_diagnostics::Diagnostic),
    /// `var` does not satisfy the trait requirements in `traits`.
    UnsatisfiedTraits {
        /// The type variable that fails the trait bounds.
        var: TypeVarID,
        /// The unsatisfied requirements, each located at the bound that introduced it.
        traits: Vec<Loc<TraitReq>>,
        /// Not read by `into_diagnostic_impl` when building the diagnostic.
        target_loc: Loc<()>,
    },
    /// A mismatch originating from a constraint (see `ConstraintSource`).
    FromConstraints {
        /// The expected side of the mismatch.
        expected: UnificationTrace,
        /// The side that was actually found.
        got: UnificationTrace,
        /// Which constraint produced the mismatch. Selects an explanatory note.
        source: ConstraintSource,
        /// Where the error is reported (used instead of the unification point).
        loc: Loc<()>,
        /// Whether the types are rendered with meta information.
        is_meta_error: bool,
    },
}
442
/// Convenience alias for results whose error is a user-facing [`Diagnostic`].
pub type Result<T> = std::result::Result<T, Diagnostic>;
444
445pub fn error_pattern_type_mismatch(
446 reason: Loc<()>,
447 type_state: &TypeState,
448) -> impl Fn(Diagnostic, TypeMismatch) -> Diagnostic + '_ {
449 move |diag, tm| {
450 let (expected, got) = tm.display_e_g(type_state);
451 diag.message(format!(
452 "Pattern type mismatch. Expected {expected} got {got}"
453 ))
454 .primary_label(format!("expected {expected}, got {got}"))
455 .secondary_label(reason, format!("because this has type {expected}"))
456 }
457}