1use itertools::Itertools;
2use std::collections::{BTreeSet, HashMap};
3
4use num::BigInt;
5use serde::{Deserialize, Serialize};
6use spade_common::{
7 id_tracker::ExprID,
8 location_info::{Loc, WithLocation},
9 name::NameID,
10};
11use spade_hir::TraitName;
12use spade_types::{meta_types::MetaType, KnownType};
13
14pub type TypeEquations = HashMap<TypedExpression, TypeVar>;
15
16#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
17pub struct TraitReq {
18 pub name: TraitName,
19 pub type_params: Vec<TypeVar>,
20}
21
22impl WithLocation for TraitReq {}
23
24impl TraitReq {
25 pub fn display_with_meta(&self, display_meta: bool) -> String {
26 if self.type_params.is_empty() {
27 format!("{}", self.name)
28 } else {
29 format!(
30 "{}<{}>",
31 self.name,
32 self.type_params
33 .iter()
34 .map(|t| format!("{}", t.display_with_meta(display_meta)))
35 .join(", ")
36 )
37 }
38 }
39}
40
41impl std::fmt::Display for TraitReq {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 write!(f, "{}", self.display_with_meta(false))
44 }
45}
46impl std::fmt::Debug for TraitReq {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 if self.type_params.is_empty() {
49 write!(f, "{}", self.name)
50 } else {
51 write!(
52 f,
53 "{}<{}>",
54 self.name,
55 self.type_params.iter().map(|t| format!("{t:?}")).join(", ")
56 )
57 }
58 }
59}
60
61#[derive(Clone, Serialize, Deserialize)]
62pub struct TraitList {
63 pub inner: Vec<Loc<TraitReq>>,
64}
65
66impl TraitList {
67 pub fn empty() -> Self {
68 Self { inner: vec![] }
69 }
70
71 pub fn from_vec(inner: Vec<Loc<TraitReq>>) -> Self {
72 Self { inner }
73 }
74
75 pub fn get_trait(&self, name: &TraitName) -> Option<&Loc<TraitReq>> {
76 self.inner.iter().find(|t| &t.name == name)
77 }
78
79 pub fn get_trait_with_type_params(
80 &self,
81 name: &TraitName,
82 type_params: &[TypeVar],
83 ) -> Option<&Loc<TraitReq>> {
84 self.inner
85 .iter()
86 .find(|t| &t.name == name && &t.type_params.as_slice() == &type_params)
87 }
88
89 pub fn extend(self, other: Self) -> Self {
90 let merged = self
91 .inner
92 .into_iter()
93 .chain(other.inner.into_iter())
94 .collect::<BTreeSet<_>>()
95 .into_iter()
96 .collect_vec();
97
98 TraitList { inner: merged }
99 }
100
101 pub fn display_with_meta(&self, display_meta: bool) -> String {
102 self.inner
103 .iter()
104 .map(|t| t.inner.display_with_meta(display_meta))
105 .join(" + ")
106 }
107}
108
109impl PartialEq for TraitList {
112 fn eq(&self, _other: &Self) -> bool {
113 true
114 }
115}
116impl Eq for TraitList {}
117impl std::hash::Hash for TraitList {
118 fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
119}
120impl PartialOrd for TraitList {
121 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122 Some(self.cmp(other))
123 }
124}
125impl Ord for TraitList {
126 fn cmp(&self, _other: &Self) -> std::cmp::Ordering {
127 std::cmp::Ordering::Equal
128 }
129}
130
131impl std::fmt::Display for TraitList {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 write!(f, "{}", self.display_with_meta(false))
134 }
135}
136impl std::fmt::Debug for TraitList {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 write!(f, "{}", self.display_with_meta(true))
139 }
140}
141
142#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
148pub enum TypeVar {
149 Known(Loc<()>, KnownType, Vec<TypeVar>),
151 Unknown(Loc<()>, u64, TraitList, MetaType),
154}
155
156impl WithLocation for TypeVar {}
157
158impl TypeVar {
159 pub fn array(loc: Loc<()>, inner: TypeVar, size: TypeVar) -> Self {
160 TypeVar::Known(loc, KnownType::Array, vec![inner, size])
161 }
162
163 pub fn tuple(loc: Loc<()>, inner: Vec<TypeVar>) -> Self {
164 TypeVar::Known(loc, KnownType::Tuple, inner)
165 }
166
167 pub fn unit(loc: Loc<()>) -> Self {
168 TypeVar::Known(loc, KnownType::Tuple, Vec::new())
169 }
170
171 pub fn wire(loc: Loc<()>, inner: TypeVar) -> Self {
172 TypeVar::Known(loc, KnownType::Wire, vec![inner])
173 }
174
175 pub fn backward(loc: Loc<()>, inner: TypeVar) -> Self {
176 TypeVar::Known(
177 loc,
178 KnownType::Inverted,
179 vec![TypeVar::Known(loc, KnownType::Wire, vec![inner])],
180 )
181 }
182
183 pub fn inverted(loc: Loc<()>, inner: TypeVar) -> Self {
184 TypeVar::Known(loc, KnownType::Inverted, vec![inner])
185 }
186
187 pub fn expect_known<T, U, K, O>(&self, on_known: K, on_unknown: U) -> T
188 where
189 U: FnOnce() -> T,
190 K: FnOnce(&KnownType, &[TypeVar]) -> T,
191 {
192 match self {
193 TypeVar::Unknown(_, _, _, _) => on_unknown(),
194 TypeVar::Known(_, k, v) => on_known(k, v),
195 }
196 }
197
198 pub fn expect_named<T, U, K, O>(&self, on_named: K, on_unknown: U, on_other: O) -> T
199 where
200 U: FnOnce() -> T,
201 K: FnOnce(&NameID, &[TypeVar]) -> T,
202 O: FnOnce(&TypeVar) -> T,
203 {
204 match self {
205 TypeVar::Unknown(_, _, _, _) => on_unknown(),
206 TypeVar::Known(_, KnownType::Named(name), params) => on_named(name, params),
207 other => on_other(other),
208 }
209 }
210
211 pub fn expect_named_or_inverted<T, U, K, O>(
219 &self,
220 inverted_now: bool,
221 on_named: K,
222 on_unknown: U,
223 on_other: O,
224 ) -> T
225 where
226 U: FnOnce() -> T,
227 K: FnOnce(bool, &NameID, &[TypeVar]) -> T,
228 O: FnOnce(&TypeVar) -> T,
229 {
230 match self {
231 TypeVar::Unknown(_, _, _, _) => on_unknown(),
232 TypeVar::Known(_, KnownType::Inverted, params) => {
233 if params.len() != 1 {
234 panic!("Found wire with {} params", params.len())
235 }
236 params[0].expect_named_or_inverted(!inverted_now, on_named, on_unknown, on_other)
237 }
238 TypeVar::Known(_, KnownType::Named(name), params) => {
239 on_named(inverted_now, name, params)
240 }
241 other => on_other(other),
242 }
243 }
244
245 pub fn expect_specific_named<T, U, K, O>(
246 &self,
247 name: NameID,
248 on_correct: K,
249 on_unknown: U,
250 on_other: O,
251 ) -> T
252 where
253 U: FnOnce() -> T,
254 K: FnOnce(&[TypeVar]) -> T,
255 O: FnOnce(&TypeVar) -> T,
256 {
257 match self {
258 TypeVar::Unknown(_, _, _, _) => on_unknown(),
259 TypeVar::Known(_, k, v) if k == &KnownType::Named(name) => on_correct(v),
260 other => on_other(other),
261 }
262 }
263
264 pub fn expect_integer<T, U, K, O>(&self, on_integer: K, on_unknown: U, on_other: O) -> T
268 where
269 U: FnOnce() -> T,
270 K: FnOnce(BigInt) -> T,
271 O: FnOnce(&TypeVar) -> T,
272 {
273 match self {
274 TypeVar::Known(_, KnownType::Integer(size), params) => {
275 assert!(params.is_empty());
276 on_integer(size.clone())
277 }
278 TypeVar::Unknown(_, _, _, _) => on_unknown(),
279 other => on_other(other),
280 }
281 }
282
283 pub fn display_with_meta(&self, display_meta: bool) -> String {
284 match self {
285 TypeVar::Known(_, KnownType::Named(t), params) => {
286 let generics = if params.is_empty() {
287 String::new()
288 } else {
289 format!(
290 "<{}>",
291 params
292 .iter()
293 .map(|p| format!("{}", p.display_with_meta(display_meta)))
294 .collect::<Vec<_>>()
295 .join(", ")
296 )
297 };
298 format!("{}{}", t, generics)
299 }
300 TypeVar::Known(_, KnownType::Integer(inner), _) => {
301 format!("{inner}")
302 }
303 TypeVar::Known(_, KnownType::Bool(inner), _) => {
304 format!("{inner}")
305 }
306 TypeVar::Known(_, KnownType::Tuple, params) => {
307 format!(
308 "({})",
309 params
310 .iter()
311 .map(|t| format!("{}", t.display_with_meta(display_meta)))
312 .collect::<Vec<_>>()
313 .join(", ")
314 )
315 }
316 TypeVar::Known(_, KnownType::Array, params) => {
317 format!(
318 "[{}; {}]",
319 params[0].display_with_meta(display_meta),
320 params[1].display_with_meta(display_meta)
321 )
322 }
323 TypeVar::Known(_, KnownType::Wire, params) => {
324 format!("&{}", params[0].display_with_meta(display_meta))
325 }
326 TypeVar::Known(_, KnownType::Inverted, params) => {
327 format!("inv {}", params[0].display_with_meta(display_meta))
328 }
329 TypeVar::Unknown(_, _, traits, meta) if traits.inner.is_empty() => {
330 if display_meta {
331 format!("{meta}")
332 } else {
333 "_".to_string()
334 }
335 }
336 TypeVar::Unknown(_, _, traits, _meta) => {
338 format!(
339 "{}",
340 traits
341 .inner
342 .iter()
343 .map(|t| format!("{}", t.display_with_meta(display_meta)))
344 .join("+"),
345 )
346 }
347 }
348 }
349}
350
351impl std::fmt::Debug for TypeVar {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353 match self {
354 TypeVar::Known(_, KnownType::Named(t), params) => {
355 let generics = if params.is_empty() {
356 String::new()
357 } else {
358 format!(
359 "<{}>",
360 params
361 .iter()
362 .map(|p| format!("{:?}", p))
363 .collect::<Vec<_>>()
364 .join(", ")
365 )
366 };
367 write!(f, "{}{}", t, generics)
368 }
369 TypeVar::Known(_, KnownType::Integer(inner), _) => {
370 write!(f, "{inner}")
371 }
372 TypeVar::Known(_, KnownType::Bool(inner), _) => {
373 write!(f, "{inner}")
374 }
375 TypeVar::Known(_, KnownType::Tuple, params) => {
376 write!(
377 f,
378 "({})",
379 params
380 .iter()
381 .map(|t| format!("{:?}", t))
382 .collect::<Vec<_>>()
383 .join(", ")
384 )
385 }
386 TypeVar::Known(_, KnownType::Array, params) => {
387 write!(f, "[{:?}; {:?}]", params[0], params[1])
388 }
389 TypeVar::Known(_, KnownType::Wire, params) => write!(f, "&{:?}", params[0]),
390 TypeVar::Known(_, KnownType::Inverted, params) => write!(f, "inv {:?}", params[0]),
391 TypeVar::Unknown(_, id, traits, meta_type) => write!(
392 f,
393 "t{id}({}, {meta_type})",
394 traits.inner.iter().map(|t| format!("{t}")).join("+")
395 ),
396 }
397 }
398}
399
400impl std::fmt::Display for TypeVar {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 write!(f, "{}", self.display_with_meta(false))
403 }
404}
405
406#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
407pub enum TypedExpression {
408 Id(ExprID),
409 Name(NameID),
410}
411
412impl WithLocation for TypedExpression {}
413
414impl std::fmt::Display for TypedExpression {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 match self {
417 TypedExpression::Id(i) => {
418 write!(f, "%{}", i.0)
419 }
420 TypedExpression::Name(p) => {
421 write!(f, "{}#{}", p, p.0)
422 }
423 }
424 }
425}