1use itertools::Itertools;
2
3use spade_common::location_info::{FullSpan, Loc, WithLocation};
4use spade_diagnostics::Diagnostic;
5
6use crate::constraints::ConstraintSource;
7
8use super::equation::{TraitReq, TypeVar};
9
/// One side (expected or got) of a failed unification.
///
/// `failing` is the type that could not be unified. `inside` is the type most
/// recently recorded for this side by `UnificationErrorExt::add_context`
/// (presumably the larger type in which `failing` was found). When present,
/// `inside` is what `Display` and `outer()` report.
#[derive(Debug, PartialEq, Clone)]
pub struct UnificationTrace {
    /// The type that failed to unify.
    pub failing: TypeVar,
    /// Context type recorded by `add_context`, if any.
    pub inside: Option<TypeVar>,
}
/// Opts the trace into the `WithLocation` helpers (e.g. `at_loc`).
impl WithLocation for UnificationTrace {}
22impl std::fmt::Display for UnificationTrace {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.outer())
25 }
26}
27
28impl UnificationTrace {
29 pub fn new(failing: TypeVar) -> Self {
30 Self {
31 failing,
32 inside: None,
33 }
34 }
35
36 pub fn outer(&self) -> &TypeVar {
37 self.inside.as_ref().unwrap_or(&self.failing)
38 }
39
40 pub fn display_with_meta(&self, meta: bool) -> String {
41 self.inside
42 .as_ref()
43 .unwrap_or(&self.failing)
44 .display_with_meta(meta)
45 }
46}
47pub trait UnificationErrorExt<T>: Sized {
48 fn add_context(
49 self,
50 got: TypeVar,
51 expected: TypeVar,
52 ) -> std::result::Result<T, UnificationError>;
53
54 fn into_default_diagnostic(
56 self,
57 unification_point: impl Into<FullSpan> + Clone,
58 ) -> std::result::Result<T, Diagnostic> {
59 self.into_diagnostic(unification_point, |d, _| d)
60 }
61
62 fn into_diagnostic_or_default<F>(
64 self,
65 unification_point: impl Into<FullSpan> + Clone,
66 message: Option<F>,
67 ) -> std::result::Result<T, Diagnostic>
68 where
69 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
70 {
71 if let Some(message) = message {
72 self.into_diagnostic(unification_point, message)
73 } else {
74 self.into_diagnostic(unification_point, |d, _| d)
75 }
76 }
77
78 fn into_diagnostic<F>(
84 self,
85 unification_point: impl Into<FullSpan> + Clone,
86 message: F,
87 ) -> std::result::Result<T, Diagnostic>
88 where
89 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
90 {
91 self.into_diagnostic_impl(unification_point, false, message)
92 }
93
94 fn into_diagnostic_no_expected_source<F>(
95 self,
96 unification_point: impl Into<FullSpan> + Clone,
97 message: F,
98 ) -> std::result::Result<T, Diagnostic>
99 where
100 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
101 {
102 self.into_diagnostic_impl(unification_point, true, message)
103 }
104
105 fn into_diagnostic_impl<F>(
106 self,
107 unification_point: impl Into<FullSpan> + Clone,
108 omit_expected_source: bool,
109 message: F,
110 ) -> std::result::Result<T, Diagnostic>
111 where
112 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic;
113}
114impl<T> UnificationErrorExt<T> for std::result::Result<T, UnificationError> {
115 fn add_context(
116 self,
117 got: TypeVar,
118 expected: TypeVar,
119 ) -> std::result::Result<T, UnificationError> {
120 match self {
121 Ok(val) => Ok(val),
122 Err(UnificationError::Normal(TypeMismatch {
123 e: mut old_e,
124 g: mut old_g,
125 })) => {
126 old_e.inside.replace(expected);
127 old_g.inside.replace(got);
128 Err(UnificationError::Normal(TypeMismatch {
129 e: old_e,
130 g: old_g,
131 }))
132 }
133 Err(UnificationError::MetaMismatch(TypeMismatch {
134 e: mut old_e,
135 g: mut old_g,
136 })) => {
137 old_e.inside.replace(expected);
138 old_g.inside.replace(got);
139 Err(UnificationError::MetaMismatch(TypeMismatch {
140 e: old_e,
141 g: old_g,
142 }))
143 }
144 e @ Err(UnificationError::UnsatisfiedTraits { .. }) => e,
145 e @ Err(
146 UnificationError::FromConstraints { .. } | UnificationError::Specific { .. },
147 ) => e,
148 }
149 }
150
151 fn into_diagnostic_impl<F>(
152 self,
153 unification_point: impl Into<FullSpan> + Clone,
154 omit_expected_source: bool,
155 message: F,
156 ) -> std::result::Result<T, Diagnostic>
157 where
158 F: Fn(Diagnostic, TypeMismatch) -> Diagnostic,
159 {
160 self.map_err(|err| {
161 let display_meta = match &err {
162 UnificationError::Normal { .. } => false,
163 UnificationError::MetaMismatch { .. } => true,
164 _ => false,
165 };
166 match err {
167 UnificationError::Normal(TypeMismatch { e, g })
168 | UnificationError::MetaMismatch(TypeMismatch { e, g }) => {
169 let e_disp = e.display_with_meta(display_meta);
170 let g_disp = g.display_with_meta(display_meta);
171 let msg = format!("Expected type {e_disp}, got {g_disp}");
172 let diag = Diagnostic::error(unification_point.clone(), msg)
173 .primary_label(format!("Expected {e_disp}"));
174 let diag = message(
175 diag,
176 TypeMismatch {
177 e: e.clone(),
178 g: g.clone(),
179 },
180 );
181
182 let diag = if !omit_expected_source {
183 add_known_type_context(
184 diag,
185 unification_point.clone(),
186 &e.failing,
187 display_meta,
188 )
189 } else {
190 diag
191 };
192
193 let diag =
194 add_known_type_context(diag, unification_point, &g.failing, display_meta);
195 diag.type_error(
196 format!("{}", e.failing.display_with_meta(display_meta)),
197 e.inside.map(|o| o.display_with_meta(display_meta)),
198 format!("{}", g.failing.display_with_meta(display_meta)),
199 g.inside.map(|o| o.display_with_meta(display_meta)),
200 )
201 }
202 UnificationError::UnsatisfiedTraits {
203 var,
204 traits,
205 target_loc: _,
206 } => {
207 let trait_bound_loc = ().at_loc(&traits[0]);
208 let impls_str = if traits.len() >= 2 {
209 format!(
210 "{} and {}",
211 traits[0..traits.len() - 1]
212 .iter()
213 .map(|i| i.inner.display_with_meta(display_meta))
214 .join(", "),
215 traits[traits.len() - 1]
216 )
217 } else {
218 format!("{}", traits[0].display_with_meta(display_meta))
219 };
220 let short_msg = format!("{var} does not implement {impls_str}");
221 Diagnostic::error(
222 unification_point,
223 format!("Trait bound not satisfied. {short_msg}"),
224 )
225 .primary_label(short_msg)
226 .secondary_label(
227 trait_bound_loc,
228 "Required because of the trait bound specified here",
229 )
230 }
231 UnificationError::FromConstraints {
232 expected,
233 got,
234 source,
235 loc,
236 is_meta_error,
237 } => {
238 let diag = Diagnostic::error(
239 loc,
240 format!(
241 "Expected type {}, got {}",
242 expected.display_with_meta(is_meta_error),
243 got.display_with_meta(is_meta_error)
244 ),
245 )
246 .primary_label(format!(
247 "Expected {}, got {}",
248 expected.display_with_meta(is_meta_error),
249 got.display_with_meta(is_meta_error)
250 ));
251
252 let diag = diag.type_error(
253 format!("{}", expected.failing.display_with_meta(is_meta_error)),
254 expected
255 .inside
256 .as_ref()
257 .map(|o| o.display_with_meta(is_meta_error)),
258 format!("{}", got.failing.display_with_meta(is_meta_error)),
259 got.inside
260 .as_ref()
261 .map(|o| o.display_with_meta(is_meta_error)),
262 );
263
264 match source {
265 ConstraintSource::AdditionOutput => diag.note(
266 "Addition creates one more output bit than the input to avoid overflow"
267 .to_string(),
268 ),
269 ConstraintSource::MultOutput => diag.note(
270 "The size of a multiplication is the sum of the operand sizes"
271 .to_string(),
272 ),
273 ConstraintSource::ArrayIndexing => {
274 diag.note(
276 "because the value is used as an index to an array".to_string(),
277 )
278 }
279 ConstraintSource::MemoryIndexing => {
280 diag.note(
282 "because the value is used as an index to a memory".to_string(),
283 )
284 }
285 ConstraintSource::Concatenation => diag.note(
286 "The size of a concatenation is the sum of the operand sizes"
287 .to_string(),
288 ),
289 ConstraintSource::ArraySize => {
290 diag.note("The number of array elements must match")
291 }
292 ConstraintSource::RangeIndex => diag,
293 ConstraintSource::RangeIndexOutputSize => diag.note(
294 "The output of a range index is an array inferred from the indices",
295 ),
296 ConstraintSource::TypeLevelIf | ConstraintSource::Where => diag,
297 ConstraintSource::PipelineRegOffset { .. } => diag,
298 ConstraintSource::PipelineRegCount { reg, total } => Diagnostic::error(
299 total,
300 format!("Expected {expected} in this pipeline."),
301 )
302 .primary_label(format!("Expected {expected} stages"))
303 .secondary_label(reg, format!("This final register is number {got}")),
304 ConstraintSource::PipelineAvailDepth => diag,
305 }
306 }
307 UnificationError::Specific(e) => e,
308 }
309 })
310 }
311}
312
313fn add_known_type_context(
314 diag: Diagnostic,
315 unification_point: impl Into<FullSpan> + Clone,
316 failing: &TypeVar,
317 meta: bool,
318) -> Diagnostic {
319 match failing {
320 TypeVar::Known(k, _, _) => {
321 if FullSpan::from(k) != unification_point.clone().into() {
322 diag.secondary_label(
323 k,
324 format!("Type {} inferred here", failing.display_with_meta(meta)),
325 )
326 } else {
327 diag
328 }
329 }
330 TypeVar::Unknown(k, _, _, _) => {
331 if FullSpan::from(k) != unification_point.clone().into() {
332 diag.secondary_label(
333 k,
334 format!("Type {} inferred here", failing.display_with_meta(meta)),
335 )
336 } else {
337 diag
338 }
339 }
340 }
341}
342
/// The two sides of a type mismatch, rendered as "Expected type {e}, got {g}".
#[derive(Debug, PartialEq, Clone)]
pub struct TypeMismatch {
    /// The expected side of the mismatch.
    pub e: UnificationTrace,
    /// The side that was actually found ("got").
    pub g: UnificationTrace,
}
350impl TypeMismatch {
351 pub fn is_meta_error(&self) -> bool {
352 matches!(self.e.failing, TypeVar::Unknown(_, _, _, _))
353 || matches!(self.g.failing, TypeVar::Unknown(_, _, _, _))
354 }
355
356 pub fn display_e_g(&self) -> (String, String) {
357 let is_meta = self.is_meta_error();
358 (
359 self.e.display_with_meta(is_meta),
360 self.g.display_with_meta(is_meta),
361 )
362 }
363}
364impl std::fmt::Display for TypeMismatch {
365 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 write!(f, "expected {}, got {}", self.e, self.g)
367 }
368}
369
/// The ways in which unifying two types can fail.
#[derive(Debug, PartialEq, Clone)]
pub enum UnificationError {
    /// The types differ; rendered without meta-type information.
    Normal(TypeMismatch),
    /// The types differ; rendered with meta-type information.
    MetaMismatch(TypeMismatch),
    /// A fully built diagnostic that is reported unchanged.
    Specific(spade_diagnostics::Diagnostic),
    /// `var` does not implement the listed `traits`.
    UnsatisfiedTraits {
        var: TypeVar,
        /// Must be non-empty: the first entry's location is used for the
        /// "trait bound specified here" label.
        traits: Vec<Loc<TraitReq>>,
        /// NOTE(review): not read when rendering the diagnostic in this file.
        target_loc: Loc<()>,
    },
    /// A mismatch found while solving a constraint.
    FromConstraints {
        expected: UnificationTrace,
        got: UnificationTrace,
        /// Where the constraint came from; selects an explanatory note.
        source: ConstraintSource,
        /// Location the diagnostic is reported at.
        loc: Loc<()>,
        /// Whether the types are displayed with meta-type information.
        is_meta_error: bool,
    },
}
388
/// Result type whose error is a rendered [`Diagnostic`].
pub type Result<T> = std::result::Result<T, Diagnostic>;
390
391pub fn error_pattern_type_mismatch(
392 reason: Loc<()>,
393) -> impl Fn(Diagnostic, TypeMismatch) -> Diagnostic {
394 move |diag,
395 TypeMismatch {
396 e: expected,
397 g: got,
398 }| {
399 diag.message(format!(
400 "Pattern type mismatch. Expected {expected} got {got}"
401 ))
402 .primary_label(format!("expected {expected}, got {got}"))
403 .secondary_label(reason, format!("because this has type {expected}"))
404 }
405}