1use itertools::Itertools;
2use std::collections::{BTreeMap, HashMap};
3
4use num::BigInt;
5use serde::{Deserialize, Serialize};
6use spade_common::{id_tracker::ExprID, location_info::Loc, name::NameID};
7use spade_types::{meta_types::MetaType, KnownType};
8
9use crate::{
10 traits::{TraitList, TraitReq},
11 HasType, TypeState,
12};
13
/// Handle to a type variable stored in a [`TypeState`].
///
/// Use [`TypeVarID::resolve`] to look up the variable the handle currently
/// refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct TypeVarID {
    /// Index of the variable. `resolve` first follows `get_type` and then uses
    /// the resulting ID's `inner` to index `TypeState::type_vars`.
    pub inner: usize,
    /// Key of the type state this ID was created for. `resolve` asserts that it
    /// is one of the state's `keys`, to catch IDs being used with the wrong state.
    pub type_state_key: u64,
}
21
22impl TypeVarID {
23 pub fn resolve(self, state: &TypeState) -> &TypeVar {
24 assert!(
25 state.keys.contains(&self.type_state_key),
26 "Type var key mismatch. Type states are being mixed incorrectly. Type state has {:?}, var has {}", state.keys, self.type_state_key
27 );
28 let final_id = self.get_type(state);
30 state.type_vars.get(final_id.inner).unwrap()
31 }
32
33 pub fn replace_inside(
34 self,
35 from: TypeVarID,
36 to: TypeVarID,
37 state: &mut TypeState,
38 ) -> TypeVarID {
39 if self.get_type(state) == from.get_type(state) {
40 to
41 } else {
42 let mut new = self.resolve(state).clone();
43 match &mut new {
44 TypeVar::Known(_, _known_type, params) => {
45 params
46 .iter_mut()
47 .for_each(|var| *var = var.replace_inside(from, to, state));
48 }
49 TypeVar::Unknown(_, _, trait_list, _) => {
50 trait_list.inner.iter_mut().for_each(|var| {
51 let TraitReq {
52 name: _,
53 type_params,
54 } = &mut var.inner;
55
56 type_params
57 .iter_mut()
58 .for_each(|t| *t = t.replace_inside(from, to, state));
59 })
60 }
61 };
62
63 new.insert(state)
67 }
68 }
69
70 pub fn display(self, type_state: &TypeState) -> String {
71 self.display_with_meta(false, type_state)
72 }
73
74 pub fn display_with_meta(self, meta: bool, type_state: &TypeState) -> String {
75 match self.resolve(type_state) {
76 TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
77 TypeVar::Known(_, KnownType::Named(t), params) => {
78 let generics = if params.is_empty() {
79 String::new()
80 } else {
81 format!(
82 "<{}>",
83 params
84 .iter()
85 .map(|p| format!("{}", p.display_with_meta(meta, type_state)))
86 .collect::<Vec<_>>()
87 .join(", ")
88 )
89 };
90 format!("{}{}", t, generics)
91 }
92 TypeVar::Known(_, KnownType::Integer(inner), _) => {
93 format!("{inner}")
94 }
95 TypeVar::Known(_, KnownType::Bool(inner), _) => {
96 format!("{inner}")
97 }
98 TypeVar::Known(_, KnownType::String(inner), _) => {
99 format!("{inner:?}")
100 }
101 TypeVar::Known(_, KnownType::Tuple, params) => {
102 format!(
103 "({})",
104 params
105 .iter()
106 .map(|t| format!("{}", t.display_with_meta(meta, type_state)))
107 .collect::<Vec<_>>()
108 .join(", ")
109 )
110 }
111 TypeVar::Known(_, KnownType::Array, params) => {
112 format!(
113 "[{}; {}]",
114 params[0].display_with_meta(meta, type_state),
115 params[1].display_with_meta(meta, type_state)
116 )
117 }
118 TypeVar::Known(_, KnownType::Wire, params) => {
119 format!("&{}", params[0].display_with_meta(meta, type_state))
120 }
121 TypeVar::Known(_, KnownType::Inverted, params) => {
122 format!("inv {}", params[0].display_with_meta(meta, type_state))
123 }
124 TypeVar::Unknown(_, _, traits, meta_type) => match meta_type {
125 MetaType::Type => {
126 if !traits.inner.is_empty() {
127 traits
128 .inner
129 .iter()
130 .map(|t| t.display_with_meta(meta, type_state))
131 .join(" + ")
132 } else {
133 "_".to_string()
134 }
135 }
136 _ => {
137 if meta {
138 format!("{}", meta_type)
139 } else {
140 format!("_")
141 }
142 }
143 },
144 }
145 }
146
147 pub fn debug_resolve(self, state: &TypeState) -> TypeVarString {
148 match self.resolve(state) {
149 TypeVar::Known(_, base, params) => {
150 let params = if params.is_empty() {
151 "".to_string()
152 } else {
153 format!(
154 "<{}>",
155 params.iter().map(|t| t.debug_resolve(state).0).join(", ")
156 )
157 };
158 let base = match base {
159 KnownType::Named(name_id) => format!("{name_id}"),
160 KnownType::Integer(big_int) => format!("{big_int}"),
161 KnownType::Bool(val) => format!("{val}"),
162 KnownType::String(val) => format!("{val:?}"),
163 KnownType::Tuple => format!("Tuple"),
164 KnownType::Array => format!("Array"),
165 KnownType::Wire => format!("&"),
166 KnownType::Inverted => format!("inv &"),
167 KnownType::Error => format!("{{error}}"),
168 };
169 TypeVarString(format!("{base}{params}"), self)
170 }
171 TypeVar::Unknown(_, id, traits, _meta_type) => {
172 let traits = if traits.inner.is_empty() {
173 "".to_string()
174 } else {
175 format!(
176 "({})",
177 traits
178 .inner
179 .iter()
180 .map(|t| t.debug_display(state))
181 .join(" + ")
182 )
183 };
184 TypeVarString(format!("t{id}{traits}"), self)
185 }
186 }
187 }
188}
189
/// A type variable belonging to a generic definition.
///
/// Use `make_copy` / `make_copy_with_mapping` to obtain a fresh
/// `TypeVarID` built from it, instead of using the wrapped variable directly.
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)]
pub struct TemplateTypeVarID {
    // The template's underlying variable. Equality and ordering (and therefore
    // the lookup in `make_copy_with_mapping`'s map) compare this raw ID, not
    // the resolved type.
    inner: TypeVarID,
}
196
197impl TemplateTypeVarID {
198 pub fn new(inner: TypeVarID) -> Self {
199 Self { inner }
200 }
201
202 pub fn make_copy(&self, state: &mut TypeState) -> TypeVarID {
203 self.make_copy_with_mapping(state, &mut BTreeMap::new())
204 }
205
206 pub fn make_copy_with_mapping(
207 &self,
208 state: &mut TypeState,
209 mapped: &mut BTreeMap<TemplateTypeVarID, TypeVarID>,
210 ) -> TypeVarID {
211 if let Some(prev) = mapped.get(self) {
212 return *prev;
213 }
214
215 let new = match self.inner.resolve(state).clone() {
216 TypeVar::Known(loc, base, params) => TypeVar::Known(
217 loc,
218 base,
219 params
220 .into_iter()
221 .map(|p| TemplateTypeVarID { inner: p }.make_copy_with_mapping(state, mapped))
222 .collect(),
223 ),
224 TypeVar::Unknown(loc, id, TraitList { inner: tl }, meta_type) => TypeVar::Unknown(
225 loc,
226 id,
227 TraitList {
228 inner: tl
229 .into_iter()
230 .map(|loc| {
231 loc.map(|req| TraitReq {
232 name: req.name,
233 type_params: req
234 .type_params
235 .into_iter()
236 .map(|p| {
237 TemplateTypeVarID { inner: p }
238 .make_copy_with_mapping(state, mapped)
239 })
240 .collect(),
241 })
242 })
243 .collect(),
244 },
245 meta_type,
246 ),
247 };
248 let result = state.add_type_var(new);
249 mapped.insert(*self, result);
250 result
251 }
252}
253
/// Maps each typed expression (an expression ID or a name) to its type variable.
pub type TypeEquations = HashMap<TypedExpression, TypeVarID>;
255
/// Debug rendering of a type variable (as produced by
/// `TypeVarID::debug_resolve`), together with the ID it was rendered from.
#[derive(Debug, Clone)]
pub struct TypeVarString(pub String, pub TypeVarID);
259
260impl std::fmt::Display for TypeVarString {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 write!(f, "{}", self.0)
263 }
264}
265
/// A type variable as stored in the type state.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Debug)]
pub enum TypeVar {
    /// A type whose base is known: `(location, base type, type parameters)`.
    /// The parameters are themselves type variable IDs and may be unknown.
    Known(Loc<()>, KnownType, Vec<TypeVarID>),
    /// A not-yet-known type: `(location, id, trait requirements, meta type)`.
    /// The `u64` id is what `TypeVarID::debug_resolve` prints as `t{id}`.
    Unknown(Loc<()>, u64, TraitList, MetaType),
}
279
280impl TypeVar {
281 pub fn into_known(&self, type_state: &TypeState) -> Option<KnownTypeVar> {
282 match self {
283 TypeVar::Known(loc, base, params) => Some(KnownTypeVar(
284 loc.clone(),
285 base.clone(),
286 params
287 .iter()
288 .map(|t| t.resolve(type_state).into_known(type_state))
289 .collect::<Option<_>>()?,
290 )),
291 TypeVar::Unknown(_, _, _, _) => None,
292 }
293 }
294
295 pub fn insert(self, into: &mut TypeState) -> TypeVarID {
296 into.add_type_var(self)
297 }
298
299 pub fn array(loc: Loc<()>, inner: TypeVarID, size: TypeVarID) -> Self {
300 TypeVar::Known(loc, KnownType::Array, vec![inner, size])
301 }
302
303 pub fn tuple(loc: Loc<()>, inner: Vec<TypeVarID>) -> Self {
304 TypeVar::Known(loc, KnownType::Tuple, inner)
305 }
306
307 pub fn unit(loc: Loc<()>) -> Self {
308 TypeVar::Known(loc, KnownType::Tuple, Vec::new())
309 }
310
311 pub fn wire(loc: Loc<()>, inner: TypeVarID) -> Self {
312 TypeVar::Known(loc, KnownType::Wire, vec![inner])
313 }
314
315 pub fn backward(loc: Loc<()>, inner: TypeVarID, type_state: &mut TypeState) -> Self {
316 TypeVar::Known(
317 loc,
318 KnownType::Inverted,
319 vec![type_state.add_type_var(TypeVar::Known(loc, KnownType::Wire, vec![inner]))],
320 )
321 }
322
323 pub fn inverted(loc: Loc<()>, inner: TypeVarID) -> Self {
324 TypeVar::Known(loc, KnownType::Inverted, vec![inner])
325 }
326
327 pub fn expect_known<T, U, K, O>(&self, on_known: K, on_unknown: U) -> T
328 where
329 U: FnOnce() -> T,
330 K: FnOnce(&KnownType, &[TypeVarID]) -> T,
331 {
332 match self {
333 TypeVar::Unknown(_, _, _, _) => on_unknown(),
334 TypeVar::Known(_, k, v) => on_known(k, v),
335 }
336 }
337
338 pub fn expect_named<T, E, U, K, O>(
339 &self,
340 on_named: K,
341 on_unknown: U,
342 on_other: O,
343 on_error: E,
344 ) -> T
345 where
346 U: FnOnce() -> T,
347 K: FnOnce(&NameID, &[TypeVarID]) -> T,
348 E: FnOnce() -> T,
349 O: FnOnce(&TypeVar) -> T,
350 {
351 match self {
352 TypeVar::Unknown(_, _, _, _) => on_unknown(),
353 TypeVar::Known(_, KnownType::Named(name), params) => on_named(name, params),
354 TypeVar::Known(_, KnownType::Error, _) => on_error(),
355 other => on_other(other),
356 }
357 }
358
359 pub fn resolve_named_or_inverted(
363 &self,
364 inverted_now: bool,
365 type_state: &TypeState,
366 ) -> ResolvedNamedOrInverted {
367 match self {
368 TypeVar::Unknown(_, _, _, _) => ResolvedNamedOrInverted::Unknown,
369 TypeVar::Known(_, KnownType::Inverted, params) => {
370 if params.len() != 1 {
371 panic!("Found wire with {} params", params.len())
372 }
373 params[0]
374 .resolve(type_state)
375 .resolve_named_or_inverted(!inverted_now, type_state)
376 }
377 TypeVar::Known(_, KnownType::Named(name), params) => {
378 ResolvedNamedOrInverted::Named(inverted_now, name.clone(), params.clone())
379 }
380 _ => ResolvedNamedOrInverted::Other,
381 }
382 }
383
384 pub fn expect_specific_named<T, U, K, O>(
385 &self,
386 name: NameID,
387 on_correct: K,
388 on_unknown: U,
389 on_other: O,
390 ) -> T
391 where
392 U: FnOnce() -> T,
393 K: FnOnce(&[TypeVarID]) -> T,
394 O: FnOnce(&TypeVar) -> T,
395 {
396 match self {
397 TypeVar::Unknown(_, _, _, _) => on_unknown(),
398 TypeVar::Known(_, k, v) if k == &KnownType::Named(name) => on_correct(v),
399 other => on_other(other),
400 }
401 }
402
403 pub fn expect_integer<T, E, U, K, O>(
407 &self,
408 on_integer: K,
409 on_unknown: U,
410 on_other: O,
411 on_error: E,
412 ) -> T
413 where
414 U: FnOnce() -> T,
415 E: FnOnce() -> T,
416 K: FnOnce(BigInt) -> T,
417 O: FnOnce(&TypeVar) -> T,
418 {
419 match self {
420 TypeVar::Known(_, KnownType::Integer(size), params) => {
421 assert!(params.is_empty());
422 on_integer(size.clone())
423 }
424 TypeVar::Known(_, KnownType::Error, _) => on_error(),
425 TypeVar::Unknown(_, _, _, _) => on_unknown(),
426 other => on_other(other),
427 }
428 }
429
430 pub fn expect_string<T, E, U, K, O>(
434 &self,
435 on_string: K,
436 on_unknown: U,
437 on_other: O,
438 on_error: E,
439 ) -> T
440 where
441 U: FnOnce() -> T,
442 E: FnOnce() -> T,
443 K: FnOnce(String) -> T,
444 O: FnOnce(&TypeVar) -> T,
445 {
446 match self {
447 TypeVar::Known(_, KnownType::String(val), params) => {
448 assert!(params.is_empty());
449 on_string(val.clone())
450 }
451 TypeVar::Known(_, KnownType::Error, _) => on_error(),
452 TypeVar::Unknown(_, _, _, _) => on_unknown(),
453 other => on_other(other),
454 }
455 }
456
457 pub fn display(&self, type_state: &TypeState) -> String {
458 self.display_with_meta(false, type_state)
459 }
460
461 pub fn display_with_meta(&self, display_meta: bool, type_state: &TypeState) -> String {
462 match self {
463 TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
464 TypeVar::Known(_, KnownType::Named(t), params) => {
465 let generics = if params.is_empty() {
466 String::new()
467 } else {
468 format!(
469 "<{}>",
470 params
471 .iter()
472 .map(|p| format!("{}", p.display_with_meta(display_meta, type_state)))
473 .collect::<Vec<_>>()
474 .join(", ")
475 )
476 };
477 format!("{}{}", t, generics)
478 }
479 TypeVar::Known(_, KnownType::Integer(inner), _) => {
480 format!("{inner}")
481 }
482 TypeVar::Known(_, KnownType::Bool(inner), _) => {
483 format!("{inner}")
484 }
485 TypeVar::Known(_, KnownType::String(inner), _) => {
486 format!("{inner:?}")
487 }
488 TypeVar::Known(_, KnownType::Tuple, params) => {
489 format!(
490 "({})",
491 params
492 .iter()
493 .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
494 .collect::<Vec<_>>()
495 .join(", ")
496 )
497 }
498 TypeVar::Known(_, KnownType::Array, params) => {
499 format!(
500 "[{}; {}]",
501 params[0].display_with_meta(display_meta, type_state),
502 params[1].display_with_meta(display_meta, type_state)
503 )
504 }
505 TypeVar::Known(_, KnownType::Wire, params) => {
506 format!("&{}", params[0].display_with_meta(display_meta, type_state))
507 }
508 TypeVar::Known(_, KnownType::Inverted, params) => {
509 format!(
510 "inv {}",
511 params[0].display_with_meta(display_meta, type_state)
512 )
513 }
514 TypeVar::Unknown(_, _, traits, meta) if traits.inner.is_empty() => {
515 if display_meta {
516 format!("{meta}")
517 } else {
518 "_".to_string()
519 }
520 }
521 TypeVar::Unknown(_, _, traits, _meta) => {
523 format!(
524 "{}",
525 traits
526 .inner
527 .iter()
528 .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
529 .join("+"),
530 )
531 }
532 }
533 }
534}
535
/// A fully known type: `(location, base type, parameters)`, where every
/// parameter is itself fully known. Produced by `TypeVar::into_known`.
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct KnownTypeVar(pub Loc<()>, pub KnownType, pub Vec<KnownTypeVar>);
538
539impl KnownTypeVar {
540 pub fn insert(&self, type_state: &mut TypeState) -> TypeVarID {
541 let KnownTypeVar(loc, base, params) = self;
542 TypeVar::Known(
543 loc.clone(),
544 base.clone(),
545 params.into_iter().map(|p| p.insert(type_state)).collect(),
546 )
547 .insert(type_state)
548 }
549}
550
551impl std::fmt::Display for KnownTypeVar {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 let KnownTypeVar(_, base, params) = self;
554
555 match base {
556 KnownType::Error => {
557 write!(f, "{{unknown}}")
558 }
559 KnownType::Named(name_id) => {
560 write!(f, "{name_id}")?;
561 if !params.is_empty() {
562 write!(f, "<{}>", params.iter().map(|t| format!("{t}")).join(", "))?;
563 }
564 Ok(())
565 }
566 KnownType::Integer(val) => write!(f, "{val}"),
567 KnownType::Bool(val) => write!(f, "{val}"),
568 KnownType::String(val) => write!(f, "{val:?}"),
569 KnownType::Tuple => {
570 write!(f, "({})", params.iter().map(|t| format!("{t}")).join(", "))
571 }
572 KnownType::Array => write!(f, "[{}; {}]", params[0], params[1]),
573 KnownType::Wire => write!(f, "&{}", params[0]),
574 KnownType::Inverted => write!(f, "inv {}", params[0]),
575 }
576 }
577}
578
/// Result of `TypeVar::resolve_named_or_inverted`.
pub enum ResolvedNamedOrInverted {
    /// An unknown type variable was reached.
    Unknown,
    /// A named type was reached. The `bool` is the caller's `inverted_now`
    /// flag after toggling it once per `inv` layer that was passed through.
    /// The vector holds the named type's parameters.
    Named(bool, NameID, Vec<TypeVarID>),
    /// Any other known type (including the error type).
    Other,
}
584
/// Something that has a type: either an expression (by ID) or a named item.
#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
pub enum TypedExpression {
    /// An expression, displayed as `%<id>`.
    Id(ExprID),
    /// A name, displayed as `<name>#<id>`.
    Name(NameID),
}
590
591impl std::fmt::Display for TypedExpression {
592 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593 match self {
594 TypedExpression::Id(i) => {
595 write!(f, "%{}", i.0)
596 }
597 TypedExpression::Name(p) => {
598 write!(f, "{}#{}", p, p.0)
599 }
600 }
601 }
602}