1use std::fmt;
18use std::marker::PhantomData;
19use std::sync::{Arc, Mutex, MutexGuard};
20
21use ghost_cell::GhostToken;
22
23use crate::dag::{Dag, DagLike};
24
25use super::{
26 Bound, CompleteBound, Error, Final, Incomplete, Type, TypeInner, UbElement, WithGhostToken,
27};
28
/// Zero-sized marker used to tie a value to the `'brand` lifetime.
///
/// `'brand` occurs both in the argument and in the return type of the
/// function-pointer type, which makes the marker invariant in `'brand`.
type InvariantLifetime<'brand> = PhantomData<fn(&'brand ()) -> &'brand ()>;
34
35#[derive(Clone)]
45pub struct Context<'brand> {
46 inner: Arc<Mutex<WithGhostToken<'brand, ContextInner<'brand>>>>,
47}
48
49struct ContextInner<'brand> {
50 slab: Vec<Bound<'brand>>,
51}
52
53impl fmt::Debug for Context<'_> {
54 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
55 let id = Arc::as_ptr(&self.inner) as usize;
56 write!(f, "inference_ctx_{:08x}", id)
57 }
58}
59
60impl PartialEq for Context<'_> {
61 fn eq(&self, other: &Self) -> bool {
62 Arc::ptr_eq(&self.inner, &other.inner)
63 }
64}
65impl Eq for Context<'_> {}
66
67impl<'brand> Context<'brand> {
68 pub fn with_context<R, F>(fun: F) -> R
70 where
71 F: for<'new_brand> FnOnce(Context<'new_brand>) -> R,
72 {
73 GhostToken::new(|token| {
74 let ctx = Context::new(token);
75 fun(ctx)
76 })
77 }
78
79 pub fn new(token: GhostToken<'brand>) -> Self {
81 Context {
82 inner: Arc::new(Mutex::new(WithGhostToken {
83 token,
84 inner: ContextInner { slab: vec![] },
85 })),
86 }
87 }
88
89 fn alloc_bound(&self, bound: Bound<'brand>) -> BoundRef<'brand> {
91 let mut lock = self.lock();
92 lock.alloc_bound(bound)
93 }
94
95 pub fn alloc_free(&self, name: String) -> BoundRef<'brand> {
97 self.alloc_bound(Bound::Free(name))
98 }
99
100 pub fn alloc_unit(&self) -> BoundRef<'brand> {
102 self.alloc_bound(Bound::Complete(Final::unit()))
103 }
104
105 pub fn alloc_complete(&self, data: Arc<Final>) -> BoundRef<'brand> {
107 self.alloc_bound(Bound::Complete(data))
108 }
109
110 pub fn alloc_sum(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
112 let mut lock = self.lock();
113 if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
114 lock.alloc_bound(Bound::Complete(Final::sum(data1, data2)))
115 } else {
116 lock.alloc_bound(Bound::Sum(left.inner, right.inner))
117 }
118 }
119
120 pub fn alloc_product(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
122 let mut lock = self.lock();
123 if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
124 lock.alloc_bound(Bound::Complete(Final::product(data1, data2)))
125 } else {
126 lock.alloc_bound(Bound::Product(left.inner, right.inner))
127 }
128 }
129
130 pub fn shallow_clone(&self) -> Self {
136 Self {
137 inner: Arc::clone(&self.inner),
138 }
139 }
140
141 pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> {
143 if self == other {
144 Ok(())
145 } else {
146 Err(super::Error::InferenceContextMismatch)
147 }
148 }
149
150 pub(super) fn get(&self, bound: &BoundRef<'brand>) -> Bound<'brand> {
152 let lock = self.lock();
153 lock.inner.slab[bound.index].shallow_clone()
154 }
155
156 pub(super) fn get_root_ref(
158 &self,
159 bound: &UbElement<'brand, BoundRef<'brand>>,
160 ) -> BoundRef<'brand> {
161 let mut lock = self.lock();
162 bound.root(&mut lock.token)
163 }
164
165 pub(super) fn reassign_non_complete(&self, bound: BoundRef<'brand>, new: Bound<'brand>) {
176 let mut lock = self.lock();
177 lock.reassign_non_complete(bound, new);
178 }
179
180 pub fn bind_product(
185 &self,
186 existing: &Type<'brand>,
187 prod_l: &Type<'brand>,
188 prod_r: &Type<'brand>,
189 hint: &'static str,
190 ) -> Result<(), Error> {
191 let mut lock = self.lock();
192 let existing_root = existing.inner.bound.root(&mut lock.token);
193 let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone());
194
195 lock.bind(existing_root, new_bound).map_err(|e| {
196 let new_bound = lock.alloc_bound(e.new);
197 drop(lock);
198 Error::Bind {
199 existing_bound: Incomplete::from_bound_ref(self, e.existing),
200 new_bound: Incomplete::from_bound_ref(self, new_bound),
201 hint,
202 }
203 })
204 }
205
206 pub fn unify(
210 &self,
211 ty1: &Type<'brand>,
212 ty2: &Type<'brand>,
213 hint: &'static str,
214 ) -> Result<(), Error> {
215 let mut lock = self.lock();
216 lock.unify(&ty1.inner, &ty2.inner).map_err(|e| {
217 let new_bound = lock.alloc_bound(e.new);
218 drop(lock);
219 Error::Bind {
220 existing_bound: Incomplete::from_bound_ref(self, e.existing),
221 new_bound: Incomplete::from_bound_ref(self, new_bound),
222 hint,
223 }
224 })
225 }
226
227 fn lock(&self) -> MutexGuard<'_, WithGhostToken<'brand, ContextInner<'brand>>> {
229 self.inner.lock().unwrap()
230 }
231}
232
233#[derive(Debug, Clone)]
234pub struct BoundRef<'brand> {
235 phantom: InvariantLifetime<'brand>,
236 index: usize,
237}
238
239impl<'brand> BoundRef<'brand> {
240 pub fn occurs_check_id(&self) -> OccursCheckId<'brand> {
244 OccursCheckId {
245 phantom: InvariantLifetime::default(),
246 index: self.index,
247 }
248 }
249}
250
251impl super::PointerLike for BoundRef<'_> {
252 fn ptr_eq(&self, other: &Self) -> bool {
253 self.index == other.index
254 }
255
256 fn shallow_clone(&self) -> Self {
257 BoundRef {
258 phantom: InvariantLifetime::default(),
259 index: self.index,
260 }
261 }
262}
263
264impl<'brand> DagLike for (&'_ Context<'brand>, BoundRef<'brand>) {
265 type Node = BoundRef<'brand>;
266 fn data(&self) -> &BoundRef<'brand> {
267 &self.1
268 }
269
270 fn as_dag_node(&self) -> Dag<Self> {
271 match self.0.get(&self.1) {
272 Bound::Free(..) | Bound::Complete(..) => Dag::Nullary,
273 Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => {
274 let root1 = self.0.get_root_ref(&ty1.bound);
275 let root2 = self.0.get_root_ref(&ty2.bound);
276 Dag::Binary((self.0, root1), (self.0, root2))
277 }
278 }
279 }
280}
281
282#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
283pub struct OccursCheckId<'brand> {
284 phantom: InvariantLifetime<'brand>,
285 index: usize,
286}
287
288struct BindError<'brand> {
289 existing: BoundRef<'brand>,
290 new: Bound<'brand>,
291}
292
293impl<'brand> ContextInner<'brand> {
294 fn alloc_bound(&mut self, bound: Bound<'brand>) -> BoundRef<'brand> {
295 self.slab.push(bound);
296 let index = self.slab.len() - 1;
297
298 BoundRef {
299 phantom: InvariantLifetime::default(),
300 index,
301 }
302 }
303
304 fn reassign_non_complete(&mut self, bound: BoundRef<'brand>, new: Bound<'brand>) {
305 assert!(
306 !matches!(self.slab[bound.index], Bound::Complete(..)),
307 "tried to modify finalized type",
308 );
309 self.slab[bound.index] = new;
310 }
311}
312
313impl<'brand> WithGhostToken<'brand, ContextInner<'brand>> {
314 fn complete_pair_data(
320 &mut self,
321 inn1: &TypeInner<'brand>,
322 inn2: &TypeInner<'brand>,
323 ) -> Option<(Arc<Final>, Arc<Final>)> {
324 let idx1 = inn1.bound.root(&mut self.token).index;
325 let idx2 = inn2.bound.root(&mut self.token).index;
326 let bound1 = &self.slab[idx1];
327 let bound2 = &self.slab[idx2];
328 if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) {
329 Some((Arc::clone(data1), Arc::clone(data2)))
330 } else {
331 None
332 }
333 }
334 fn unify(
338 &mut self,
339 existing: &TypeInner<'brand>,
340 other: &TypeInner<'brand>,
341 ) -> Result<(), BindError<'brand>> {
342 existing
343 .bound
344 .unify(self, &other.bound, |self_, x_bound, y_bound| {
345 self_.bind(x_bound, self_.slab[y_bound.index].shallow_clone())
346 })
347 }
348
349 fn bind(
350 &mut self,
351 existing: BoundRef<'brand>,
352 new: Bound<'brand>,
353 ) -> Result<(), BindError<'brand>> {
354 let existing_bound = self.slab[existing.index].shallow_clone();
355 let bind_error = || BindError {
356 existing: existing.clone(),
357 new: new.shallow_clone(),
358 };
359
360 match (&existing_bound, &new) {
361 (_, Bound::Free(_)) => Ok(()),
363 (Bound::Free(_), _) => {
365 self.reassign_non_complete(existing, new);
367 Ok(())
368 }
369 (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
372 if existing_final == new_final {
373 Ok(())
374 } else {
375 Err(bind_error())
376 }
377 }
378 (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
380 match (complete.bound(), incomplete) {
381 (CompleteBound::Unit, _) => Err(bind_error()),
384 (
385 CompleteBound::Product(ref comp1, ref comp2),
386 Bound::Product(ref ty1, ref ty2),
387 )
388 | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
389 let bound1 = ty1.bound.root(&mut self.token);
390 let bound2 = ty2.bound.root(&mut self.token);
391 self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?;
392 self.bind(bound2, Bound::Complete(Arc::clone(comp2)))
393 }
394 _ => Err(bind_error()),
395 }
396 }
397 (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
398 | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
399 self.unify(x1, y1)?;
400 self.unify(x2, y2)?;
401 if let Some((data1, data2)) = self.complete_pair_data(y1, y2) {
409 self.reassign_non_complete(
410 existing,
411 Bound::Complete(if let Bound::Sum(..) = existing_bound {
412 Final::sum(data1, data2)
413 } else {
414 Final::product(data1, data2)
415 }),
416 );
417 }
418 Ok(())
419 }
420 (_, _) => Err(bind_error()),
421 }
422 }
423}