1use std::collections::BTreeMap;
2
3use itertools::Itertools;
/// Crate-local `HashMap`: `imbl`'s generic hash map, parameterised with the `rustc_hash`
/// Fx hasher and `imbl`'s default shared-pointer type.
type HashMap<K, V> =
    imbl::GenericHashMap<K, V, rustc_hash::FxBuildHasher, imbl::shared_ptr::DefaultSharedPtr>;
6
7use num::BigInt;
8use serde::{Deserialize, Serialize};
9use spade_common::{id_tracker::ExprID, location_info::Loc, name::NameID};
10use spade_types::{meta_types::MetaType, KnownType};
11
12use crate::{
13 traits::{TraitList, TraitReq},
14 HasType, TypeState,
15};
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
18pub struct TypeVarID {
19 pub inner: usize,
20 pub type_state_key: u64,
23}
24
25impl TypeVarID {
26 pub fn resolve(self, state: &TypeState) -> TypeVar {
27 assert!(
28 state.owned.keys.contains(&self.type_state_key),
29 "Type var key mismatch. Type states are being mixed incorrectly. Type state has {:?}, var has {}", state.owned.keys, self.type_state_key
30 );
31 let final_id = self.get_type(state);
33 state
34 .shared
35 .read_type_vars(|type_vars| type_vars.get(final_id.inner).unwrap().clone())
36 }
37
38 pub fn replace_inside(
39 self,
40 from: TypeVarID,
41 to: TypeVarID,
42 state: &mut TypeState,
43 ) -> TypeVarID {
44 if self.get_type(state) == from.get_type(state) {
45 to
46 } else {
47 let mut new = self.resolve(state).clone();
48 match &mut new {
49 TypeVar::Known(_, _known_type, params) => {
50 params
51 .iter_mut()
52 .for_each(|var| *var = var.replace_inside(from, to, state));
53 }
54 TypeVar::Unknown(_, _, trait_list, _) => {
55 trait_list.inner.iter_mut().for_each(|var| {
56 let TraitReq {
57 name: _,
58 type_params,
59 } = &mut var.inner;
60
61 type_params
62 .iter_mut()
63 .for_each(|t| *t = t.replace_inside(from, to, state));
64 })
65 }
66 };
67
68 new.insert(state)
72 }
73 }
74
75 pub fn display(self, type_state: &TypeState) -> String {
76 self.display_with_meta(false, type_state)
77 }
78
79 pub fn display_with_meta(self, meta: bool, type_state: &TypeState) -> String {
80 match self.resolve(type_state) {
81 TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
82 TypeVar::Known(_, KnownType::Named(t), params) => {
83 let generics = if params.is_empty() {
84 String::new()
85 } else {
86 format!(
87 "<{}>",
88 params
89 .iter()
90 .map(|p| format!("{}", p.display_with_meta(meta, type_state)))
91 .collect::<Vec<_>>()
92 .join(", ")
93 )
94 };
95 format!("{}{}", t, generics)
96 }
97 TypeVar::Known(_, KnownType::Integer(inner), _) => {
98 format!("{inner}")
99 }
100 TypeVar::Known(_, KnownType::Bool(inner), _) => {
101 format!("{inner}")
102 }
103 TypeVar::Known(_, KnownType::String(inner), _) => {
104 format!("{inner:?}")
105 }
106 TypeVar::Known(_, KnownType::Tuple, params) => {
107 format!(
108 "({})",
109 params
110 .iter()
111 .map(|t| format!("{}", t.display_with_meta(meta, type_state)))
112 .collect::<Vec<_>>()
113 .join(", ")
114 )
115 }
116 TypeVar::Known(_, KnownType::Array, params) => {
117 format!(
118 "[{}; {}]",
119 params[0].display_with_meta(meta, type_state),
120 params[1].display_with_meta(meta, type_state)
121 )
122 }
123 TypeVar::Known(_, KnownType::Wire, params) => {
124 format!("&{}", params[0].display_with_meta(meta, type_state))
125 }
126 TypeVar::Known(_, KnownType::Inverted, params) => {
127 format!("inv {}", params[0].display_with_meta(meta, type_state))
128 }
129 TypeVar::Unknown(_, _, traits, meta_type) => match meta_type {
130 MetaType::Type => {
131 if !traits.inner.is_empty() {
132 traits
133 .inner
134 .iter()
135 .map(|t| t.display_with_meta(meta, type_state))
136 .join(" + ")
137 } else {
138 "_".to_string()
139 }
140 }
141 _ => {
142 if meta {
143 format!("{}", meta_type)
144 } else {
145 format!("_")
146 }
147 }
148 },
149 }
150 }
151
152 pub fn debug_resolve(self, state: &TypeState) -> TypeVarString {
153 match self.resolve(state) {
154 TypeVar::Known(_, base, params) => {
155 let params = if params.is_empty() {
156 "".to_string()
157 } else {
158 format!(
159 "<{}>",
160 params.iter().map(|t| t.debug_resolve(state).0).join(", ")
161 )
162 };
163 let base = match base {
164 KnownType::Named(name_id) => format!("{name_id}"),
165 KnownType::Integer(big_int) => format!("{big_int}"),
166 KnownType::Bool(val) => format!("{val}"),
167 KnownType::String(val) => format!("{val:?}"),
168 KnownType::Tuple => format!("Tuple"),
169 KnownType::Array => format!("Array"),
170 KnownType::Wire => format!("&"),
171 KnownType::Inverted => format!("inv &"),
172 KnownType::Error => format!("{{error}}"),
173 };
174 TypeVarString(format!("{base}{params}"), self)
175 }
176 TypeVar::Unknown(_, id, traits, _meta_type) => {
177 let traits = if traits.inner.is_empty() {
178 "".to_string()
179 } else {
180 format!(
181 "({})",
182 traits
183 .inner
184 .iter()
185 .map(|t| t.debug_display(state))
186 .join(" + ")
187 )
188 };
189 TypeVarString(format!("t{id}{traits}"), self)
190 }
191 }
192 }
193}
194
195#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, PartialOrd, Ord)]
198pub struct TemplateTypeVarID {
199 inner: TypeVarID,
200}
201
202impl TemplateTypeVarID {
203 pub fn new(inner: TypeVarID) -> Self {
204 Self { inner }
205 }
206
207 pub fn make_copy(&self, state: &mut TypeState) -> TypeVarID {
208 self.make_copy_with_mapping(state, &mut BTreeMap::new())
209 }
210
211 pub fn make_copy_with_mapping(
212 &self,
213 state: &mut TypeState,
214 mapped: &mut BTreeMap<TemplateTypeVarID, TypeVarID>,
215 ) -> TypeVarID {
216 if let Some(prev) = mapped.get(self) {
217 return *prev;
218 }
219
220 let new = match self.inner.resolve(state).clone() {
221 TypeVar::Known(loc, base, params) => TypeVar::Known(
222 loc,
223 base,
224 params
225 .into_iter()
226 .map(|p| TemplateTypeVarID { inner: p }.make_copy_with_mapping(state, mapped))
227 .collect(),
228 ),
229 TypeVar::Unknown(loc, id, TraitList { inner: tl }, meta_type) => TypeVar::Unknown(
230 loc,
231 id,
232 TraitList {
233 inner: tl
234 .into_iter()
235 .map(|loc| {
236 loc.map(|req| TraitReq {
237 name: req.name,
238 type_params: req
239 .type_params
240 .into_iter()
241 .map(|p| {
242 TemplateTypeVarID { inner: p }
243 .make_copy_with_mapping(state, mapped)
244 })
245 .collect(),
246 })
247 })
248 .collect(),
249 },
250 meta_type,
251 ),
252 };
253 let result = state.add_type_var(new);
254 mapped.insert(*self, result);
255 result
256 }
257}
258
/// Map from typed expressions (expression IDs or names) to their type variables.
pub type TypeEquations = HashMap<TypedExpression, TypeVarID>;
260
261#[derive(Debug, Clone)]
263pub struct TypeVarString(pub String, pub TypeVarID);
264
265impl std::fmt::Display for TypeVarString {
266 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267 write!(f, "{}", self.0)
268 }
269}
270
271#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Debug)]
277pub enum TypeVar {
278 Known(Loc<()>, KnownType, Vec<TypeVarID>),
280 Unknown(Loc<()>, u64, TraitList, MetaType),
283}
284
285impl TypeVar {
286 pub fn into_known(&self, type_state: &TypeState) -> Option<KnownTypeVar> {
287 match self {
288 TypeVar::Known(loc, base, params) => Some(KnownTypeVar(
289 loc.clone(),
290 base.clone(),
291 params
292 .iter()
293 .map(|t| t.resolve(type_state).into_known(type_state))
294 .collect::<Option<_>>()?,
295 )),
296 TypeVar::Unknown(_, _, _, _) => None,
297 }
298 }
299
300 pub fn insert(self, into: &mut TypeState) -> TypeVarID {
301 into.add_type_var(self)
302 }
303
304 pub fn array(loc: Loc<()>, inner: TypeVarID, size: TypeVarID) -> Self {
305 TypeVar::Known(loc, KnownType::Array, vec![inner, size])
306 }
307
308 pub fn tuple(loc: Loc<()>, inner: Vec<TypeVarID>) -> Self {
309 TypeVar::Known(loc, KnownType::Tuple, inner)
310 }
311
312 pub fn unit(loc: Loc<()>) -> Self {
313 TypeVar::Known(loc, KnownType::Tuple, Vec::new())
314 }
315
316 pub fn wire(loc: Loc<()>, inner: TypeVarID) -> Self {
317 TypeVar::Known(loc, KnownType::Wire, vec![inner])
318 }
319
320 pub fn backward(loc: Loc<()>, inner: TypeVarID, type_state: &mut TypeState) -> Self {
321 TypeVar::Known(
322 loc,
323 KnownType::Inverted,
324 vec![type_state.add_type_var(TypeVar::Known(loc, KnownType::Wire, vec![inner]))],
325 )
326 }
327
328 pub fn inverted(loc: Loc<()>, inner: TypeVarID) -> Self {
329 TypeVar::Known(loc, KnownType::Inverted, vec![inner])
330 }
331
332 pub fn expect_known<T, U, K, O>(&self, on_known: K, on_unknown: U) -> T
333 where
334 U: FnOnce() -> T,
335 K: FnOnce(&KnownType, &[TypeVarID]) -> T,
336 {
337 match self {
338 TypeVar::Unknown(_, _, _, _) => on_unknown(),
339 TypeVar::Known(_, k, v) => on_known(k, v),
340 }
341 }
342
343 pub fn expect_named<T, E, U, K, O>(
344 &self,
345 on_named: K,
346 on_unknown: U,
347 on_other: O,
348 on_error: E,
349 ) -> T
350 where
351 U: FnOnce() -> T,
352 K: FnOnce(&NameID, &[TypeVarID]) -> T,
353 E: FnOnce() -> T,
354 O: FnOnce(&TypeVar) -> T,
355 {
356 match self {
357 TypeVar::Unknown(_, _, _, _) => on_unknown(),
358 TypeVar::Known(_, KnownType::Named(name), params) => on_named(name, params),
359 TypeVar::Known(_, KnownType::Error, _) => on_error(),
360 other => on_other(other),
361 }
362 }
363
364 pub fn resolve_named_or_inverted(
368 &self,
369 inverted_now: bool,
370 type_state: &TypeState,
371 ) -> ResolvedNamedOrInverted {
372 match self {
373 TypeVar::Unknown(_, _, _, _) => ResolvedNamedOrInverted::Unknown,
374 TypeVar::Known(_, KnownType::Inverted, params) => {
375 if params.len() != 1 {
376 panic!("Found wire with {} params", params.len())
377 }
378 params[0]
379 .resolve(type_state)
380 .resolve_named_or_inverted(!inverted_now, type_state)
381 }
382 TypeVar::Known(_, KnownType::Named(name), params) => {
383 ResolvedNamedOrInverted::Named(inverted_now, name.clone(), params.clone())
384 }
385 _ => ResolvedNamedOrInverted::Other,
386 }
387 }
388
389 pub fn expect_specific_named<T, U, K, O>(
390 &self,
391 name: NameID,
392 on_correct: K,
393 on_unknown: U,
394 on_other: O,
395 ) -> T
396 where
397 U: FnOnce() -> T,
398 K: FnOnce(&[TypeVarID]) -> T,
399 O: FnOnce(&TypeVar) -> T,
400 {
401 match self {
402 TypeVar::Unknown(_, _, _, _) => on_unknown(),
403 TypeVar::Known(_, k, v) if k == &KnownType::Named(name) => on_correct(v),
404 other => on_other(other),
405 }
406 }
407
408 pub fn expect_integer<T, E, U, K, O>(
412 &self,
413 on_integer: K,
414 on_unknown: U,
415 on_other: O,
416 on_error: E,
417 ) -> T
418 where
419 U: FnOnce() -> T,
420 E: FnOnce() -> T,
421 K: FnOnce(BigInt) -> T,
422 O: FnOnce(&TypeVar) -> T,
423 {
424 match self {
425 TypeVar::Known(_, KnownType::Integer(size), params) => {
426 assert!(params.is_empty());
427 on_integer(size.clone())
428 }
429 TypeVar::Known(_, KnownType::Error, _) => on_error(),
430 TypeVar::Unknown(_, _, _, _) => on_unknown(),
431 other => on_other(other),
432 }
433 }
434
435 pub fn expect_string<T, E, U, K, O>(
439 &self,
440 on_string: K,
441 on_unknown: U,
442 on_other: O,
443 on_error: E,
444 ) -> T
445 where
446 U: FnOnce() -> T,
447 E: FnOnce() -> T,
448 K: FnOnce(String) -> T,
449 O: FnOnce(&TypeVar) -> T,
450 {
451 match self {
452 TypeVar::Known(_, KnownType::String(val), params) => {
453 assert!(params.is_empty());
454 on_string(val.clone())
455 }
456 TypeVar::Known(_, KnownType::Error, _) => on_error(),
457 TypeVar::Unknown(_, _, _, _) => on_unknown(),
458 other => on_other(other),
459 }
460 }
461
462 pub fn display(&self, type_state: &TypeState) -> String {
463 self.display_with_meta(false, type_state)
464 }
465
466 pub fn display_with_meta(&self, display_meta: bool, type_state: &TypeState) -> String {
467 match self {
468 TypeVar::Known(_, KnownType::Error, _) => "{unknown}".to_string(),
469 TypeVar::Known(_, KnownType::Named(t), params) => {
470 let generics = if params.is_empty() {
471 String::new()
472 } else {
473 format!(
474 "<{}>",
475 params
476 .iter()
477 .map(|p| format!("{}", p.display_with_meta(display_meta, type_state)))
478 .collect::<Vec<_>>()
479 .join(", ")
480 )
481 };
482 format!("{}{}", t, generics)
483 }
484 TypeVar::Known(_, KnownType::Integer(inner), _) => {
485 format!("{inner}")
486 }
487 TypeVar::Known(_, KnownType::Bool(inner), _) => {
488 format!("{inner}")
489 }
490 TypeVar::Known(_, KnownType::String(inner), _) => {
491 format!("{inner:?}")
492 }
493 TypeVar::Known(_, KnownType::Tuple, params) => {
494 format!(
495 "({})",
496 params
497 .iter()
498 .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
499 .collect::<Vec<_>>()
500 .join(", ")
501 )
502 }
503 TypeVar::Known(_, KnownType::Array, params) => {
504 format!(
505 "[{}; {}]",
506 params[0].display_with_meta(display_meta, type_state),
507 params[1].display_with_meta(display_meta, type_state)
508 )
509 }
510 TypeVar::Known(_, KnownType::Wire, params) => {
511 format!("&{}", params[0].display_with_meta(display_meta, type_state))
512 }
513 TypeVar::Known(_, KnownType::Inverted, params) => {
514 format!(
515 "inv {}",
516 params[0].display_with_meta(display_meta, type_state)
517 )
518 }
519 TypeVar::Unknown(_, _, traits, meta) if traits.inner.is_empty() => {
520 if display_meta {
521 format!("{meta}")
522 } else {
523 "_".to_string()
524 }
525 }
526 TypeVar::Unknown(_, _, traits, _meta) => {
528 format!(
529 "{}",
530 traits
531 .inner
532 .iter()
533 .map(|t| format!("{}", t.display_with_meta(display_meta, type_state)))
534 .join("+"),
535 )
536 }
537 }
538 }
539}
540
541#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
542pub struct KnownTypeVar(pub Loc<()>, pub KnownType, pub Vec<KnownTypeVar>);
543
544impl KnownTypeVar {
545 pub fn insert(&self, type_state: &mut TypeState) -> TypeVarID {
546 let KnownTypeVar(loc, base, params) = self;
547 TypeVar::Known(
548 loc.clone(),
549 base.clone(),
550 params.into_iter().map(|p| p.insert(type_state)).collect(),
551 )
552 .insert(type_state)
553 }
554}
555
556impl std::fmt::Display for KnownTypeVar {
557 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
558 let KnownTypeVar(_, base, params) = self;
559
560 match base {
561 KnownType::Error => {
562 write!(f, "{{unknown}}")
563 }
564 KnownType::Named(name_id) => {
565 write!(f, "{name_id}")?;
566 if !params.is_empty() {
567 write!(f, "<{}>", params.iter().map(|t| format!("{t}")).join(", "))?;
568 }
569 Ok(())
570 }
571 KnownType::Integer(val) => write!(f, "{val}"),
572 KnownType::Bool(val) => write!(f, "{val}"),
573 KnownType::String(val) => write!(f, "{val:?}"),
574 KnownType::Tuple => {
575 write!(f, "({})", params.iter().map(|t| format!("{t}")).join(", "))
576 }
577 KnownType::Array => write!(f, "[{}; {}]", params[0], params[1]),
578 KnownType::Wire => write!(f, "&{}", params[0]),
579 KnownType::Inverted => write!(f, "inv {}", params[0]),
580 }
581 }
582}
583
584pub enum ResolvedNamedOrInverted {
585 Unknown,
586 Named(bool, NameID, Vec<TypeVarID>),
587 Other,
588}
589
590#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
591pub enum TypedExpression {
592 Id(ExprID),
593 Name(NameID),
594}
595
596impl std::fmt::Display for TypedExpression {
597 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
598 match self {
599 TypedExpression::Id(i) => {
600 write!(f, "%{}", i.0)
601 }
602 TypedExpression::Name(p) => {
603 write!(f, "{}#{}", p, p.0)
604 }
605 }
606 }
607}