1use std::fmt;
18use std::sync::{Arc, Mutex, MutexGuard};
19
20use crate::dag::{Dag, DagLike};
21
22use super::{Bound, CompleteBound, Error, Final, Type, TypeInner};
23
24#[derive(Clone, Default)]
34pub struct Context {
35 slab: Arc<Mutex<Vec<Bound>>>,
36}
37
38impl fmt::Debug for Context {
39 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
40 let id = Arc::as_ptr(&self.slab) as usize;
41 write!(f, "inference_ctx_{:08x}", id)
42 }
43}
44
45impl PartialEq for Context {
46 fn eq(&self, other: &Self) -> bool {
47 Arc::ptr_eq(&self.slab, &other.slab)
48 }
49}
50impl Eq for Context {}
51
52impl Context {
53 pub fn new() -> Self {
55 Context {
56 slab: Arc::new(Mutex::new(vec![])),
57 }
58 }
59
60 fn alloc_bound(&self, bound: Bound) -> BoundRef {
62 let mut lock = self.lock();
63 lock.alloc_bound(bound)
64 }
65
66 pub fn alloc_free(&self, name: String) -> BoundRef {
68 self.alloc_bound(Bound::Free(name))
69 }
70
71 pub fn alloc_unit(&self) -> BoundRef {
73 self.alloc_bound(Bound::Complete(Final::unit()))
74 }
75
76 pub fn alloc_complete(&self, data: Arc<Final>) -> BoundRef {
78 self.alloc_bound(Bound::Complete(data))
79 }
80
81 pub fn alloc_sum(&self, left: Type, right: Type) -> BoundRef {
87 assert_eq!(
88 left.ctx, *self,
89 "left type did not match inference context of sum"
90 );
91 assert_eq!(
92 right.ctx, *self,
93 "right type did not match inference context of sum"
94 );
95
96 let mut lock = self.lock();
97 if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
98 lock.alloc_bound(Bound::Complete(Final::sum(data1, data2)))
99 } else {
100 lock.alloc_bound(Bound::Sum(left.inner, right.inner))
101 }
102 }
103
104 pub fn alloc_product(&self, left: Type, right: Type) -> BoundRef {
110 assert_eq!(
111 left.ctx, *self,
112 "left type did not match inference context of product"
113 );
114 assert_eq!(
115 right.ctx, *self,
116 "right type did not match inference context of product"
117 );
118
119 let mut lock = self.lock();
120 if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
121 lock.alloc_bound(Bound::Complete(Final::product(data1, data2)))
122 } else {
123 lock.alloc_bound(Bound::Product(left.inner, right.inner))
124 }
125 }
126
127 pub fn shallow_clone(&self) -> Self {
133 Self {
134 slab: Arc::clone(&self.slab),
135 }
136 }
137
138 pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> {
140 if self == other {
141 Ok(())
142 } else {
143 Err(super::Error::InferenceContextMismatch)
144 }
145 }
146
147 pub(super) fn get(&self, bound: &BoundRef) -> Bound {
153 bound.assert_matches_context(self);
154 let lock = self.lock();
155 lock.slab[bound.index].shallow_clone()
156 }
157
158 pub(super) fn reassign_non_complete(&self, bound: BoundRef, new: Bound) {
169 let mut lock = self.lock();
170 lock.reassign_non_complete(bound, new);
171 }
172
173 pub fn bind_product(
183 &self,
184 existing: &Type,
185 prod_l: &Type,
186 prod_r: &Type,
187 hint: &'static str,
188 ) -> Result<(), Error> {
189 assert_eq!(
190 existing.ctx, *self,
191 "attempted to bind existing type with wrong context",
192 );
193 assert_eq!(
194 prod_l.ctx, *self,
195 "attempted to bind product whose left type had wrong context",
196 );
197 assert_eq!(
198 prod_r.ctx, *self,
199 "attempted to bind product whose right type had wrong context",
200 );
201
202 let existing_root = existing.inner.bound.root();
203 let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone());
204
205 let mut lock = self.lock();
206 lock.bind(existing_root, new_bound).map_err(|e| {
207 let new_bound = lock.alloc_bound(e.new);
208 Error::Bind {
209 existing_bound: Type::wrap_bound(self, e.existing),
210 new_bound: Type::wrap_bound(self, new_bound),
211 hint,
212 }
213 })
214 }
215
216 pub fn unify(&self, ty1: &Type, ty2: &Type, hint: &'static str) -> Result<(), Error> {
220 assert_eq!(ty1.ctx, *self);
221 assert_eq!(ty2.ctx, *self);
222 let mut lock = self.lock();
223 lock.unify(&ty1.inner, &ty2.inner).map_err(|e| {
224 let new_bound = lock.alloc_bound(e.new);
225 Error::Bind {
226 existing_bound: Type::wrap_bound(self, e.existing),
227 new_bound: Type::wrap_bound(self, new_bound),
228 hint,
229 }
230 })
231 }
232
233 fn lock(&self) -> LockedContext {
235 LockedContext {
236 context: Arc::as_ptr(&self.slab),
237 slab: self.slab.lock().unwrap(),
238 }
239 }
240}
241
242#[derive(Debug, Clone)]
243pub struct BoundRef {
244 context: *const Mutex<Vec<Bound>>,
245 index: usize,
246}
247
248unsafe impl Send for BoundRef {}
258unsafe impl Sync for BoundRef {}
260
261impl BoundRef {
262 pub fn assert_matches_context(&self, ctx: &Context) {
263 assert_eq!(
264 self.context,
265 Arc::as_ptr(&ctx.slab),
266 "bound was accessed from a type inference context that did not create it",
267 );
268 }
269
270 pub fn occurs_check_id(&self) -> OccursCheckId {
274 OccursCheckId {
275 context: self.context,
276 index: self.index,
277 }
278 }
279}
280
281impl super::PointerLike for BoundRef {
282 fn ptr_eq(&self, other: &Self) -> bool {
283 debug_assert_eq!(
284 self.context, other.context,
285 "tried to compare two bounds from different inference contexts"
286 );
287 self.index == other.index
288 }
289
290 fn shallow_clone(&self) -> Self {
291 BoundRef {
292 context: self.context,
293 index: self.index,
294 }
295 }
296}
297
298impl DagLike for (&'_ Context, BoundRef) {
299 type Node = BoundRef;
300 fn data(&self) -> &BoundRef {
301 &self.1
302 }
303
304 fn as_dag_node(&self) -> Dag<Self> {
305 match self.0.get(&self.1) {
306 Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
307 Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
308 Dag::Binary((self.0, ty1.bound.root()), (self.0, ty2.bound.root()))
309 }
310 }
311 }
312}
313
314#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
315pub struct OccursCheckId {
316 context: *const Mutex<Vec<Bound>>,
317 index: usize,
318}
319
320struct BindError {
321 existing: BoundRef,
322 new: Bound,
323}
324
325struct LockedContext<'ctx> {
330 context: *const Mutex<Vec<Bound>>,
331 slab: MutexGuard<'ctx, Vec<Bound>>,
332}
333
334impl LockedContext<'_> {
335 fn alloc_bound(&mut self, bound: Bound) -> BoundRef {
336 self.slab.push(bound);
337 let index = self.slab.len() - 1;
338
339 BoundRef {
340 context: self.context,
341 index,
342 }
343 }
344
345 fn reassign_non_complete(&mut self, bound: BoundRef, new: Bound) {
346 assert!(
347 !matches!(self.slab[bound.index], Bound::Complete(..)),
348 "tried to modify finalized type",
349 );
350 self.slab[bound.index] = new;
351 }
352
353 fn complete_pair_data(
359 &self,
360 inn1: &TypeInner,
361 inn2: &TypeInner,
362 ) -> Option<(Arc<Final>, Arc<Final>)> {
363 let bound1 = &self.slab[inn1.bound.root().index];
364 let bound2 = &self.slab[inn2.bound.root().index];
365 if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) {
366 Some((Arc::clone(data1), Arc::clone(data2)))
367 } else {
368 None
369 }
370 }
371
372 fn unify(&mut self, existing: &TypeInner, other: &TypeInner) -> Result<(), BindError> {
376 existing.bound.unify(&other.bound, |x_bound, y_bound| {
377 self.bind(x_bound, self.slab[y_bound.index].shallow_clone())
378 })
379 }
380
381 fn bind(&mut self, existing: BoundRef, new: Bound) -> Result<(), BindError> {
382 let existing_bound = self.slab[existing.index].shallow_clone();
383 let bind_error = || BindError {
384 existing: existing.clone(),
385 new: new.shallow_clone(),
386 };
387
388 match (&existing_bound, &new) {
389 (_, Bound::Free(_)) => Ok(()),
391 (Bound::Free(_), _) => {
393 self.reassign_non_complete(existing, new);
395 Ok(())
396 }
397 (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
400 if existing_final == new_final {
401 Ok(())
402 } else {
403 Err(bind_error())
404 }
405 }
406 (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
408 match (complete.bound(), incomplete) {
409 (CompleteBound::Unit, _) => Err(bind_error()),
412 (
413 CompleteBound::Product(ref comp1, ref comp2),
414 Bound::Product(ref ty1, ref ty2),
415 )
416 | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
417 let bound1 = ty1.bound.root();
418 let bound2 = ty2.bound.root();
419 self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?;
420 self.bind(bound2, Bound::Complete(Arc::clone(comp2)))
421 }
422 _ => Err(bind_error()),
423 }
424 }
425 (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
426 | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
427 self.unify(x1, y1)?;
428 self.unify(x2, y2)?;
429 if let Some((data1, data2)) = self.complete_pair_data(y1, y2) {
437 self.reassign_non_complete(
438 existing,
439 Bound::Complete(if let Bound::Sum(..) = existing_bound {
440 Final::sum(data1, data2)
441 } else {
442 Final::product(data1, data2)
443 }),
444 );
445 }
446 Ok(())
447 }
448 (_, _) => Err(bind_error()),
449 }
450 }
451}