1use std::collections::HashMap;
2
3use hir::symbol_table::SymbolTable;
4use hir::{Parameter, TypeExpression, TypeSpec};
5use spade_common::id_tracker::ExprID;
6use spade_common::location_info::{Loc, WithLocation};
7use spade_common::name::NameID;
8use spade_diagnostics::Diagnostic;
9use spade_hir::{self as hir, ConstGenericWithId};
10use spade_hir::{TypeDeclaration, TypeList};
11use spade_types::{ConcreteType, KnownType};
12
13use crate::equation::{TypeVar, TypeVarID, TypedExpression};
14use crate::TypeState;
15
16pub trait HasConcreteType {
17 fn into_typed_expression(&self) -> Loc<TypedExpression>;
18}
19
20impl<T> HasConcreteType for &mut T
21where
22 T: HasConcreteType,
23{
24 fn into_typed_expression(&self) -> Loc<TypedExpression> {
25 (**self).into_typed_expression()
26 }
27}
28
29impl<T> HasConcreteType for &T
30where
31 T: HasConcreteType,
32{
33 fn into_typed_expression(&self) -> Loc<TypedExpression> {
34 (*self).into_typed_expression()
35 }
36}
37
38impl<T> HasConcreteType for Box<T>
39where
40 T: HasConcreteType,
41{
42 fn into_typed_expression(&self) -> Loc<TypedExpression> {
43 self.as_ref().into_typed_expression()
44 }
45}
46
47impl HasConcreteType for Loc<ExprID> {
48 fn into_typed_expression(&self) -> Loc<TypedExpression> {
49 TypedExpression::Id(self.inner).at_loc(self)
50 }
51}
52
53impl HasConcreteType for Loc<hir::Expression> {
54 fn into_typed_expression(&self) -> Loc<TypedExpression> {
55 TypedExpression::Id(self.id).at_loc(self)
56 }
57}
58
59impl HasConcreteType for Loc<hir::Pattern> {
60 fn into_typed_expression(&self) -> Loc<TypedExpression> {
61 TypedExpression::Id(self.id).at_loc(self)
62 }
63}
64impl HasConcreteType for Loc<ConstGenericWithId> {
65 fn into_typed_expression(&self) -> Loc<TypedExpression> {
66 TypedExpression::Id(self.id).at_loc(self)
67 }
68}
69
70impl HasConcreteType for Loc<NameID> {
71 fn into_typed_expression(&self) -> Loc<TypedExpression> {
72 TypedExpression::Name(self.inner.clone()).at_loc(self)
73 }
74}
75
76impl TypeState {
77 pub fn type_decl_to_concrete(
78 decl: &TypeDeclaration,
79 type_list: &TypeList,
80 params: Vec<ConcreteType>,
81 invert: bool,
82 ) -> ConcreteType {
83 assert!(
86 params.len() == decl.generic_args.len(),
87 "Too few type decl params in {:?}",
88 decl
89 );
90
91 let generic_subs = decl
92 .generic_args
93 .iter()
94 .zip(params.iter())
95 .map(|(lhs, rhs)| (lhs.name_id(), rhs))
96 .collect::<HashMap<_, _>>();
97
98 match &decl.kind {
99 hir::TypeDeclKind::Enum(e) => {
100 let options = e
101 .options
102 .iter()
103 .map(|(name, args)| {
104 let args = args
105 .0
106 .iter()
107 .map(|arg| {
108 (
109 arg.name.inner.clone(),
110 Self::type_spec_to_concrete(
111 &arg.ty.inner,
112 type_list,
113 &generic_subs,
114 false,
115 ),
116 )
117 })
118 .collect();
119 (name.inner.clone(), args)
120 })
121 .collect();
122 ConcreteType::Enum { options }
123 }
124 hir::TypeDeclKind::Struct(s) => {
125 let members = s
126 .members
127 .0
128 .iter()
129 .map(
130 |Parameter {
131 name: ident,
132 ty: t,
133 no_mangle: _,
134 field_translator: _,
135 }| {
136 (
137 ident.inner.clone(),
138 Self::type_spec_to_concrete(t, type_list, &generic_subs, invert),
139 )
140 },
141 )
142 .collect();
143
144 let translators = s.members.0.iter().filter_map(
145 |Parameter {
146 name,
147 field_translator,
148 ..
149 }| {
150 field_translator
151 .as_ref()
152 .map(|t| (name.inner.clone(), t.clone()))
153 },
154 );
155
156 ConcreteType::Struct {
157 name: decl.name.inner.clone(),
158 is_port: s.is_port,
159 members,
160 field_translators: translators.collect(),
161 }
162 }
163 hir::TypeDeclKind::Primitive(primitive) => ConcreteType::Single {
164 base: primitive.clone(),
165 params,
166 },
167 }
168 }
169
170 pub fn type_expr_to_concrete(
171 expr: &TypeExpression,
172 type_list: &TypeList,
173 generic_substitutions: &HashMap<NameID, &ConcreteType>,
174 invert: bool,
175 ) -> ConcreteType {
176 match &expr {
177 hir::TypeExpression::Integer(val) => ConcreteType::Integer(val.clone()),
178 hir::TypeExpression::TypeSpec(inner) => {
179 Self::type_spec_to_concrete(inner, type_list, generic_substitutions, invert)
180 }
181 hir::TypeExpression::ConstGeneric(_) => {
182 unreachable!("Const generic in type_expr_to_concrete")
183 }
184 }
185 }
186
187 pub fn type_spec_to_concrete(
188 spec: &TypeSpec,
189 type_list: &TypeList,
190 generic_substitutions: &HashMap<NameID, &ConcreteType>,
191 invert: bool,
192 ) -> ConcreteType {
193 match spec {
194 TypeSpec::Declared(name, params) => {
195 let params = params
196 .iter()
197 .map(|p| {
198 Self::type_expr_to_concrete(p, type_list, generic_substitutions, invert)
199 })
200 .collect();
201
202 let actual = type_list
203 .get(name)
204 .unwrap_or_else(|| panic!("Expected {:?} to be in type list", name));
205
206 Self::type_decl_to_concrete(actual, type_list, params, invert)
207 }
208 TypeSpec::Generic(name) => {
209 (*generic_substitutions
211 .get(name)
212 .unwrap_or_else(|| panic!("Expected a substitution for {}", name)))
213 .clone()
214 }
215 TypeSpec::Tuple(t) => {
216 let inner = t
217 .iter()
218 .map(|v| {
219 Self::type_spec_to_concrete(
220 &v.inner,
221 type_list,
222 generic_substitutions,
223 invert,
224 )
225 })
226 .collect::<Vec<_>>();
227 ConcreteType::Tuple(inner)
228 }
229 TypeSpec::Array { inner, size } => {
230 let size_type = Box::new(Self::type_expr_to_concrete(
231 size,
232 type_list,
233 generic_substitutions,
234 invert,
235 ));
236
237 let size = if let ConcreteType::Integer(size) = size_type.as_ref() {
238 size.clone()
239 } else {
240 panic!("Array size must be an integer")
241 };
242
243 ConcreteType::Array {
244 inner: Box::new(Self::type_spec_to_concrete(
245 inner,
246 type_list,
247 generic_substitutions,
248 invert,
249 )),
250 size,
251 }
252 }
253 TypeSpec::Wire(inner) => {
254 let inner = Box::new(Self::type_spec_to_concrete(
255 inner,
256 type_list,
257 generic_substitutions,
258 invert,
259 ));
260 if invert {
261 ConcreteType::Backward(inner)
262 } else {
263 ConcreteType::Wire(inner)
264 }
265 }
266 TypeSpec::Inverted(inner) => Self::type_spec_to_concrete(
267 inner,
268 type_list,
269 generic_substitutions,
270 !invert,
273 ),
274 TypeSpec::TraitSelf(_) => panic!("Trying to concretize HIR TraitSelf type"),
275 TypeSpec::Wildcard(_) => panic!("Trying to concretize HIR Wildcard type"),
276 }
277 }
278
279 pub fn inner_ungenerify_type(
280 &self,
281 var: &TypeVarID,
282 symtab: &SymbolTable,
283 type_list: &TypeList,
284 invert: bool,
285 ) -> Option<ConcreteType> {
286 match var.resolve(self) {
287 TypeVar::Known(_, KnownType::Error, _) => Some(ConcreteType::Error),
288 TypeVar::Known(_, KnownType::Named(t), params) => {
289 let params = params
290 .iter()
291 .map(|v| self.inner_ungenerify_type(v, symtab, type_list, invert))
292 .collect::<Option<Vec<_>>>()?;
293
294 type_list
295 .get(t)
296 .map(|t| Self::type_decl_to_concrete(&t.inner, type_list, params, invert))
297 }
298 TypeVar::Known(_, KnownType::Integer(val), params) => {
299 assert!(params.is_empty(), "integers cannot have type parameters");
300
301 Some(ConcreteType::Integer(val.clone()))
302 }
303 TypeVar::Known(_, KnownType::Bool(val), params) => {
304 assert!(
305 params.is_empty(),
306 "type level bools cannot have type parameters"
307 );
308
309 Some(ConcreteType::Bool(*val))
310 }
311 TypeVar::Known(_, KnownType::Array, inner) => {
312 let value = self.inner_ungenerify_type(&inner[0], symtab, type_list, invert);
313 let size = self.ungenerify_type(&inner[1], symtab, type_list).map(|t| {
314 if let ConcreteType::Integer(size) = t {
315 size
316 } else {
317 panic!("Array size must be an integer")
318 }
319 });
320
321 match (value, size) {
322 (Some(value), Some(size)) => Some(ConcreteType::Array {
323 inner: Box::new(value),
324 size,
325 }),
326 _ => None,
327 }
328 }
329 TypeVar::Known(_, KnownType::Tuple, inner) => {
330 let inner = inner
331 .iter()
332 .map(|v| self.inner_ungenerify_type(v, symtab, type_list, invert))
333 .collect::<Option<Vec<_>>>()?;
334 Some(ConcreteType::Tuple(inner))
335 }
336 TypeVar::Known(_, KnownType::Wire, inner) => {
337 if invert {
338 self.inner_ungenerify_type(&inner[0], symtab, type_list, invert)
339 .map(|t| ConcreteType::Backward(Box::new(t)))
340 } else {
341 self.inner_ungenerify_type(&inner[0], symtab, type_list, invert)
342 .map(|t| ConcreteType::Wire(Box::new(t)))
343 }
344 }
345 TypeVar::Known(_, KnownType::Inverted, inner) => {
346 self.inner_ungenerify_type(&inner[0], symtab, type_list, !invert)
347 }
348 TypeVar::Unknown(_, _, _, _) => None,
349 }
350 }
351
352 pub fn ungenerify_type(
355 &self,
356 var: &TypeVarID,
357 symtab: &SymbolTable,
358 type_list: &TypeList,
359 ) -> Option<ConcreteType> {
360 self.inner_ungenerify_type(var, symtab, type_list, false)
361 }
362
363 pub fn concrete_type_of_infallible(
366 &self,
367 id: ExprID,
368 symtab: &SymbolTable,
369 type_list: &TypeList,
370 ) -> ConcreteType {
371 self.concrete_type_of(id.nowhere(), symtab, type_list)
372 .expect("Expr had generic type")
373 }
374
375 pub fn concrete_type_of(
378 &self,
379 id: impl HasConcreteType,
380 symtab: &SymbolTable,
381 types: &TypeList,
382 ) -> Result<ConcreteType, Diagnostic> {
383 let id = id.into_typed_expression();
384 let t = self.type_of(&id.inner);
385
386 if let Some(t) = self.ungenerify_type(&t, symtab, types) {
387 Ok(t)
388 } else {
389 if std::env::var("SPADE_TRACE_TYPEINFERENCE").is_ok() {
390 println!("The incomplete type is {}", t.debug_resolve(self))
391 }
392 Err(
393 Diagnostic::error(id, "Type of expression is not fully known")
394 .primary_label("The type of this expression is not fully known")
395 .note(format!("Found incomplete type: {t}", t = t.display(self))),
396 )
397 }
398 }
399
400 pub fn concrete_type_of_name(
403 &self,
404 name: &Loc<NameID>,
405 symtab: &SymbolTable,
406 types: &TypeList,
407 ) -> Result<ConcreteType, Diagnostic> {
408 let t = self.type_of(&TypedExpression::Name(name.inner.clone()));
409
410 if let Some(t) = self.ungenerify_type(&t, symtab, types) {
411 Ok(t)
412 } else {
413 Err(
414 Diagnostic::error(name, format!("Type of {name} is not fully known"))
415 .primary_label(format!("The type of {name} is not fully known"))
416 .note(format!("Found incomplete type: {t}", t = t.display(self))),
417 )
418 }
419 }
420}