1use std::any::TypeId;
18use std::fmt;
19use std::marker::PhantomData;
20use std::sync::{Arc, Mutex, MutexGuard};
21
22use ghost_cell::GhostToken;
23
24use crate::dag::{Dag, DagLike};
25use crate::jet::Jet;
26
27use super::{
28 Bound, CompleteBound, Error, Final, Incomplete, Type, TypeInner, UbElement, WithGhostToken,
29};
30
/// Zero-sized marker that makes a containing type *invariant* in the lifetime `'b`.
///
/// A function pointer is contravariant in its argument and covariant in its return type;
/// mentioning `'b` in both positions pins the lifetime so it can neither grow nor shrink.
/// Function pointers are `Send`, `Sync` and `UnwindSafe`, so the marker does not restrict
/// those auto traits on the containing type.
type InvariantLifetime<'b> = PhantomData<fn(&'b ()) -> &'b ()>;
36
/// Type-inference context: mutex-protected storage for every bound allocated during
/// inference, branded with the lifetime `'brand`.
///
/// Deriving `Clone` clones the `Arc`, so all clones share the same underlying state (and
/// compare equal under this type's pointer-based `PartialEq`).
#[derive(Clone)]
pub struct Context<'brand> {
    // Shared state: the ghost token plus the slab of bounds and the recorded jet type.
    inner: Arc<Mutex<WithGhostToken<'brand, ContextInner<'brand>>>>,
}
50
51struct ContextInner<'brand> {
52 slab: Vec<Bound<'brand>>,
53 jet_type: Option<TypeId>,
55}
56
57impl fmt::Debug for Context<'_> {
58 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59 let id = Arc::as_ptr(&self.inner) as usize;
60 write!(f, "inference_ctx_{:08x}", id)
61 }
62}
63
64impl PartialEq for Context<'_> {
65 fn eq(&self, other: &Self) -> bool {
66 Arc::ptr_eq(&self.inner, &other.inner)
67 }
68}
69impl Eq for Context<'_> {}
70
71impl<'brand> Context<'brand> {
72 pub fn with_context<R, F>(fun: F) -> R
74 where
75 F: for<'new_brand> FnOnce(Context<'new_brand>) -> R,
76 {
77 GhostToken::new(|token| {
78 let ctx = Context::new(token);
79 fun(ctx)
80 })
81 }
82
83 pub fn new(token: GhostToken<'brand>) -> Self {
85 Context {
86 inner: Arc::new(Mutex::new(WithGhostToken {
87 token,
88 inner: ContextInner {
89 slab: vec![],
90 jet_type: None,
91 },
92 })),
93 }
94 }
95
96 fn alloc_bound(&self, bound: Bound<'brand>) -> BoundRef<'brand> {
98 let mut lock = self.lock();
99 lock.alloc_bound(bound)
100 }
101
102 pub fn alloc_free(&self, name: String) -> BoundRef<'brand> {
104 self.alloc_bound(Bound::Free(name))
105 }
106
107 pub fn alloc_unit(&self) -> BoundRef<'brand> {
109 self.alloc_bound(Bound::Complete(Final::unit()))
110 }
111
112 pub fn alloc_complete(&self, data: Arc<Final>) -> BoundRef<'brand> {
114 self.alloc_bound(Bound::Complete(data))
115 }
116
117 pub fn alloc_sum(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
119 let mut lock = self.lock();
120 if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
121 lock.alloc_bound(Bound::Complete(Final::sum(data1, data2)))
122 } else {
123 lock.alloc_bound(Bound::Sum(left.inner, right.inner))
124 }
125 }
126
127 pub fn alloc_product(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
129 let mut lock = self.lock();
130 if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
131 lock.alloc_bound(Bound::Complete(Final::product(data1, data2)))
132 } else {
133 lock.alloc_bound(Bound::Product(left.inner, right.inner))
134 }
135 }
136
137 pub fn shallow_clone(&self) -> Self {
143 Self {
144 inner: Arc::clone(&self.inner),
145 }
146 }
147
148 pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> {
150 if self == other {
151 Ok(())
152 } else {
153 Err(super::Error::InferenceContextMismatch)
154 }
155 }
156
157 pub fn check_jet(&self, jet: &dyn Jet) {
162 let new_id = jet.as_any().type_id();
163 let mut lock = self.lock();
164
165 if let Some(existing_id) = lock.inner.jet_type {
166 assert!(existing_id == new_id, "mixed jet types in context");
167
168 return;
169 }
170
171 lock.inner.jet_type = Some(new_id);
172 }
173
174 pub(super) fn get(&self, bound: &BoundRef<'brand>) -> Bound<'brand> {
176 let lock = self.lock();
177 lock.inner.slab[bound.index].shallow_clone()
178 }
179
180 pub(super) fn get_root_ref(
182 &self,
183 bound: &UbElement<'brand, BoundRef<'brand>>,
184 ) -> BoundRef<'brand> {
185 let mut lock = self.lock();
186 bound.root(&mut lock.token)
187 }
188
189 pub(super) fn reassign_non_complete(&self, bound: BoundRef<'brand>, new: Bound<'brand>) {
200 let mut lock = self.lock();
201 lock.reassign_non_complete(bound, new);
202 }
203
204 pub fn bind_product(
209 &self,
210 existing: &Type<'brand>,
211 prod_l: &Type<'brand>,
212 prod_r: &Type<'brand>,
213 hint: &'static str,
214 ) -> Result<(), Error> {
215 let mut lock = self.lock();
216 let existing_root = existing.inner.bound.root(&mut lock.token);
217 let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone());
218
219 lock.bind(existing_root, new_bound).map_err(|e| {
220 let new_bound = lock.alloc_bound(e.new);
221 drop(lock);
222 Error::Bind {
223 existing_bound: Incomplete::from_bound_ref(self, e.existing),
224 new_bound: Incomplete::from_bound_ref(self, new_bound),
225 hint,
226 }
227 })
228 }
229
230 pub fn unify(
234 &self,
235 ty1: &Type<'brand>,
236 ty2: &Type<'brand>,
237 hint: &'static str,
238 ) -> Result<(), Error> {
239 let mut lock = self.lock();
240 lock.unify(&ty1.inner, &ty2.inner).map_err(|e| {
241 let new_bound = lock.alloc_bound(e.new);
242 drop(lock);
243 Error::Bind {
244 existing_bound: Incomplete::from_bound_ref(self, e.existing),
245 new_bound: Incomplete::from_bound_ref(self, new_bound),
246 hint,
247 }
248 })
249 }
250
251 fn lock(&self) -> MutexGuard<'_, WithGhostToken<'brand, ContextInner<'brand>>> {
253 self.inner.lock().unwrap()
254 }
255}
256
/// Reference to a bound stored in a `Context`'s slab, identified by its slab index.
///
/// The `phantom` marker makes the type invariant in `'brand`. NOTE(review): presumably this
/// ties each reference to the branded context that allocated it. Confirm.
#[derive(Debug, Clone)]
pub struct BoundRef<'brand> {
    phantom: InvariantLifetime<'brand>,
    // Position of the referenced bound in the context's slab.
    index: usize,
}
262
263impl<'brand> BoundRef<'brand> {
264 pub fn occurs_check_id(&self) -> OccursCheckId<'brand> {
268 OccursCheckId {
269 phantom: InvariantLifetime::default(),
270 index: self.index,
271 }
272 }
273}
274
275impl super::PointerLike for BoundRef<'_> {
276 fn ptr_eq(&self, other: &Self) -> bool {
277 self.index == other.index
278 }
279
280 fn shallow_clone(&self) -> Self {
281 BoundRef {
282 phantom: InvariantLifetime::default(),
283 index: self.index,
284 }
285 }
286}
287
288impl<'brand> DagLike for (&'_ Context<'brand>, BoundRef<'brand>) {
289 type Node = BoundRef<'brand>;
290 fn data(&self) -> &BoundRef<'brand> {
291 &self.1
292 }
293
294 fn as_dag_node(&self) -> Dag<Self> {
295 match self.0.get(&self.1) {
296 Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
297 Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
298 let root1 = self.0.get_root_ref(&ty1.bound);
299 let root2 = self.0.get_root_ref(&ty2.bound);
300 Dag::Binary((self.0, root1), (self.0, root2))
301 }
302 }
303 }
304}
305
/// Copyable, hashable identifier for a slab slot, produced by `BoundRef::occurs_check_id`.
///
/// Equality and hashing depend only on the slab index (the phantom marker is zero-sized).
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct OccursCheckId<'brand> {
    phantom: InvariantLifetime<'brand>,
    // Slab index copied from the originating `BoundRef`.
    index: usize,
}
311
312struct BindError<'brand> {
313 existing: BoundRef<'brand>,
314 new: Bound<'brand>,
315}
316
317impl<'brand> ContextInner<'brand> {
318 fn alloc_bound(&mut self, bound: Bound<'brand>) -> BoundRef<'brand> {
319 self.slab.push(bound);
320 let index = self.slab.len() - 1;
321
322 BoundRef {
323 phantom: InvariantLifetime::default(),
324 index,
325 }
326 }
327
328 fn reassign_non_complete(&mut self, bound: BoundRef<'brand>, new: Bound<'brand>) {
329 assert!(
330 !matches!(self.slab[bound.index], Bound::Complete(..)),
331 "tried to modify finalized type",
332 );
333 self.slab[bound.index] = new;
334 }
335}
336
// Inference operations that need both the slab and the ghost token at once.
impl<'brand> WithGhostToken<'brand, ContextInner<'brand>> {
    /// If both `inn1` and `inn2` resolve (via `bound.root`) to complete bounds, returns
    /// clones of their final data; otherwise returns `None`.
    fn complete_pair_data(
        &mut self,
        inn1: &TypeInner<'brand>,
        inn2: &TypeInner<'brand>,
    ) -> Option<(Arc<Final>, Arc<Final>)> {
        let idx1 = inn1.bound.root(&mut self.token).index;
        let idx2 = inn2.bound.root(&mut self.token).index;
        let bound1 = &self.slab[idx1];
        let bound2 = &self.slab[idx2];
        if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) {
            Some((Arc::clone(data1), Arc::clone(data2)))
        } else {
            None
        }
    }
    /// Unifies the `bound` elements of `existing` and `other` via `UbElement::unify`.
    ///
    /// The callback binds `x_bound` to a shallow clone of the bound stored at `y_bound`.
    /// NOTE(review): which of the two references `UbElement::unify` passes as `x_bound` vs
    /// `y_bound` is decided by that function, not here. Confirm before relying on it.
    fn unify(
        &mut self,
        existing: &TypeInner<'brand>,
        other: &TypeInner<'brand>,
    ) -> Result<(), BindError<'brand>> {
        existing
            .bound
            .unify(self, &other.bound, |self_, x_bound, y_bound| {
                self_.bind(x_bound, self_.slab[y_bound.index].shallow_clone())
            })
    }

    /// Constrains the bound stored at `existing` to be compatible with `new`.
    ///
    /// Cases, in match order:
    /// - `new` is free: nothing to do.
    /// - `existing` is free: it is overwritten with `new`.
    /// - both complete: they must be equal.
    /// - one complete, the other incomplete: the incomplete side must have the same shape
    ///   (sum/sum or product/product), and its children are recursively bound to the
    ///   complete side's children.
    /// - both sums or both products: children are unified pairwise; if `new`'s children are
    ///   then complete, `existing` is overwritten with the corresponding complete bound.
    /// - anything else is a mismatch.
    ///
    /// # Errors
    ///
    /// Returns a `BindError` naming `existing` and a shallow clone of `new` on mismatch
    /// (possibly from a recursive call on a child).
    fn bind(
        &mut self,
        existing: BoundRef<'brand>,
        new: Bound<'brand>,
    ) -> Result<(), BindError<'brand>> {
        // Take a local shallow clone so the slab can be mutated while we inspect it.
        let existing_bound = self.slab[existing.index].shallow_clone();
        let bind_error = || BindError {
            existing: existing.clone(),
            new: new.shallow_clone(),
        };

        match (&existing_bound, &new) {
            (_, Bound::Free(_)) => Ok(()),
            (Bound::Free(_), _) => {
                self.reassign_non_complete(existing, new);
                Ok(())
            }
            (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
                if existing_final == new_final {
                    Ok(())
                } else {
                    Err(bind_error())
                }
            }
            (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
                // `incomplete` here is a Sum or Product: Free and Complete were handled above.
                // NOTE(review): when `existing` is the incomplete side, only its children are
                // bound; the node itself is not rewritten to `Bound::Complete` in this arm.
                match (complete.bound(), incomplete) {
                    // A unit can never match a sum or product.
                    (CompleteBound::Unit, _) => Err(bind_error()),
                    (
                        CompleteBound::Product(ref comp1, ref comp2),
                        Bound::Product(ref ty1, ref ty2),
                    )
                    | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
                        let bound1 = ty1.bound.root(&mut self.token);
                        let bound2 = ty2.bound.root(&mut self.token);
                        self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?;
                        self.bind(bound2, Bound::Complete(Arc::clone(comp2)))
                    }
                    // Shape mismatch (sum vs product).
                    _ => Err(bind_error()),
                }
            }
            (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
            | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
                self.unify(x1, y1)?;
                self.unify(x2, y2)?;
                // Unification may have made both children complete; if so, replace `existing`
                // with the equivalent complete bound (sum or product, matching its shape).
                if let Some((data1, data2)) = self.complete_pair_data(y1, y2) {
                    self.reassign_non_complete(
                        existing,
                        Bound::Complete(if let Bound::Sum(..) = existing_bound {
                            Final::sum(data1, data2)
                        } else {
                            Final::product(data1, data2)
                        }),
                    );
                }
                Ok(())
            }
            // Remaining combinations are sum vs product: incompatible shapes.
            (_, _) => Err(bind_error()),
        }
    }
}