1use shape_ast::ast::Span;
10use shape_ast::error::{ErrorNote, ShapeError, SourceLocation};
11use std::collections::{HashMap, HashSet};
12
/// Identifier of a lexical region tracked by the borrow checker.
///
/// A plain integer compared by value; `BorrowChecker::new` uses `RegionId(0)` as the root
/// region and hands out increasing ids from `enter_region`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct RegionId(pub u32);
16
/// One live borrow of a local slot, as recorded by `BorrowChecker::create_borrow`.
#[derive(Debug, Clone)]
pub struct BorrowRecord {
    /// Slot whose value is borrowed (the key this record is stored under).
    pub borrowed_slot: u16,
    /// `true` for an exclusive (mutable) borrow, `false` for a shared one.
    pub is_exclusive: bool,
    /// Region that was current when the borrow was created.
    pub origin_region: RegionId,
    /// Region whose exit releases this borrow. `create_borrow` currently sets this to the
    /// same value as `origin_region`.
    pub borrow_region: RegionId,
    /// Slot identifying the reference produced by this borrow; `check_no_escape` matches
    /// against it. Presumably the slot that holds the reference value — confirm with callers.
    pub ref_slot: u16,
    /// Span supplied when the borrow was created; stored but not read by the checker itself.
    pub span: Span,
    /// Location of the borrow, attached as a note to later conflict/escape diagnostics.
    pub source_location: Option<SourceLocation>,
}
35
/// Kind of borrow requested from the borrow checker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BorrowMode {
    /// Read-only borrow; any number of these may coexist on one slot.
    Shared,
    /// Mutable borrow; it cannot coexist with any other borrow of the same slot.
    Exclusive,
}

impl BorrowMode {
    /// Returns `true` for [`BorrowMode::Exclusive`] and `false` for [`BorrowMode::Shared`].
    fn is_exclusive(self) -> bool {
        match self {
            BorrowMode::Exclusive => true,
            BorrowMode::Shared => false,
        }
    }
}
48
49pub struct BorrowChecker {
55 current_region: RegionId,
57 region_stack: Vec<RegionId>,
59 next_region_id: u32,
61 active_borrows: HashMap<u16, Vec<BorrowRecord>>,
63 exclusively_borrowed: HashSet<u16>,
65 shared_borrow_count: HashMap<u16, u32>,
67 ref_slots_by_region: HashMap<RegionId, Vec<u16>>,
69}
70
71impl BorrowChecker {
72 pub fn new() -> Self {
74 Self {
75 current_region: RegionId(0),
76 region_stack: vec![RegionId(0)],
77 next_region_id: 1,
78 active_borrows: HashMap::new(),
79 exclusively_borrowed: HashSet::new(),
80 shared_borrow_count: HashMap::new(),
81 ref_slots_by_region: HashMap::new(),
82 }
83 }
84
85 pub fn enter_region(&mut self) -> RegionId {
87 let region = RegionId(self.next_region_id);
88 self.next_region_id += 1;
89 self.region_stack.push(region);
90 self.current_region = region;
91 region
92 }
93
94 pub fn exit_region(&mut self) {
96 let exiting = self.current_region;
97
98 self.release_borrows_in_region(exiting);
100
101 self.region_stack.pop();
102 self.current_region = self.region_stack.last().copied().unwrap_or(RegionId(0));
103 }
104
105 pub fn current_region(&self) -> RegionId {
107 self.current_region
108 }
109
110 pub fn create_borrow(
118 &mut self,
119 slot: u16,
120 ref_slot: u16,
121 mode: BorrowMode,
122 span: Span,
123 source_location: Option<SourceLocation>,
124 ) -> Result<(), ShapeError> {
125 if mode.is_exclusive() {
126 if self.exclusively_borrowed.contains(&slot) {
128 return Err(self.make_borrow_conflict_error(
129 "B0001",
130 slot,
131 source_location,
132 "cannot mutably borrow this value because it is already borrowed",
133 "end the previous borrow before creating a mutable borrow, or use a shared borrow",
134 ));
135 }
136 if self.shared_borrow_count.get(&slot).copied().unwrap_or(0) > 0 {
137 return Err(self.make_borrow_conflict_error(
138 "B0001",
139 slot,
140 source_location,
141 "cannot mutably borrow this value while shared borrows are active",
142 "move the mutable borrow later, or make prior borrows immutable-only reads",
143 ));
144 }
145 self.exclusively_borrowed.insert(slot);
146 } else {
147 if self.exclusively_borrowed.contains(&slot) {
149 return Err(self.make_borrow_conflict_error(
150 "B0001",
151 slot,
152 source_location,
153 "cannot immutably borrow this value because it is mutably borrowed",
154 "drop the mutable borrow before taking an immutable borrow",
155 ));
156 }
157 *self.shared_borrow_count.entry(slot).or_insert(0) += 1;
158 }
159
160 let record = BorrowRecord {
161 borrowed_slot: slot,
162 is_exclusive: mode.is_exclusive(),
163 origin_region: self.current_region,
164 borrow_region: self.current_region,
165 ref_slot,
166 span,
167 source_location,
168 };
169
170 self.active_borrows.entry(slot).or_default().push(record);
171
172 self.ref_slots_by_region
173 .entry(self.current_region)
174 .or_default()
175 .push(slot);
176
177 Ok(())
178 }
179
180 pub fn check_write_allowed(
182 &self,
183 slot: u16,
184 source_location: Option<SourceLocation>,
185 ) -> Result<(), ShapeError> {
186 if let Some(borrows) = self.active_borrows.get(&slot) {
187 if !borrows.is_empty() {
188 return Err(self.make_borrow_conflict_error(
189 "B0002",
190 slot,
191 source_location,
192 "cannot write to this value while it is borrowed",
193 "move this write after the borrow ends",
194 ));
195 }
196 }
197 Ok(())
198 }
199
200 pub fn check_read_allowed(
204 &self,
205 slot: u16,
206 source_location: Option<SourceLocation>,
207 ) -> Result<(), ShapeError> {
208 if self.exclusively_borrowed.contains(&slot) {
209 return Err(self.make_borrow_conflict_error(
210 "B0001",
211 slot,
212 source_location,
213 "cannot read this value while it is mutably borrowed",
214 "read through the existing reference, or move the read after the borrow ends",
215 ));
216 }
217 Ok(())
218 }
219
220 pub fn check_no_escape(
223 &self,
224 ref_slot: u16,
225 source_location: Option<SourceLocation>,
226 ) -> Result<(), ShapeError> {
227 for borrows in self.active_borrows.values() {
229 for borrow in borrows {
230 if borrow.ref_slot == ref_slot {
231 let mut location = source_location;
232 if let Some(loc) = location.as_mut() {
233 loc.hints.push(
234 "keep references within the call/lexical scope where they were created"
235 .to_string(),
236 );
237 loc.notes.push(ErrorNote {
238 message: "borrow originates here".to_string(),
239 location: borrow.source_location.clone(),
240 });
241 }
242 return Err(ShapeError::SemanticError {
243 message: "[B0003] reference cannot escape its scope".to_string(),
244 location,
245 });
246 }
247 }
248 }
249 Ok(())
250 }
251
252 fn release_borrows_in_region(&mut self, region: RegionId) {
254 if let Some(slots) = self.ref_slots_by_region.remove(®ion) {
255 for slot in slots {
256 if let Some(borrows) = self.active_borrows.get_mut(&slot) {
257 borrows.retain(|b| b.borrow_region != region);
258
259 let has_exclusive = borrows.iter().any(|b| b.is_exclusive);
261 let shared_count = borrows.iter().filter(|b| !b.is_exclusive).count() as u32;
262
263 if !has_exclusive {
264 self.exclusively_borrowed.remove(&slot);
265 }
266 if shared_count == 0 {
267 self.shared_borrow_count.remove(&slot);
268 } else {
269 self.shared_borrow_count.insert(slot, shared_count);
270 }
271
272 if borrows.is_empty() {
273 self.active_borrows.remove(&slot);
274 }
275 }
276 }
277 }
278 }
279
280 pub fn reset(&mut self) {
282 self.current_region = RegionId(0);
283 self.region_stack = vec![RegionId(0)];
284 self.next_region_id = 1;
285 self.active_borrows.clear();
286 self.exclusively_borrowed.clear();
287 self.shared_borrow_count.clear();
288 self.ref_slots_by_region.clear();
289 }
290
291 fn first_conflicting_borrow(&self, slot: u16) -> Option<&BorrowRecord> {
292 self.active_borrows
293 .get(&slot)
294 .and_then(|borrows| borrows.first())
295 }
296
297 fn make_borrow_conflict_error(
298 &self,
299 code: &str,
300 slot: u16,
301 source_location: Option<SourceLocation>,
302 message: &str,
303 help: &str,
304 ) -> ShapeError {
305 let mut location = source_location;
306 if let Some(loc) = location.as_mut() {
307 loc.hints.push(help.to_string());
308 if let Some(conflict) = self.first_conflicting_borrow(slot) {
309 loc.notes.push(ErrorNote {
310 message: "first conflicting borrow occurs here".to_string(),
311 location: conflict.source_location.clone(),
312 });
313 }
314 }
315 ShapeError::SemanticError {
316 message: format!("[{}] {} (slot {})", code, message, slot),
317 location,
318 }
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Dummy span; the checker stores spans but never inspects them.
    fn span() -> Span {
        Span { start: 0, end: 1 }
    }

    #[test]
    fn test_single_exclusive_borrow_ok() {
        let mut bc = BorrowChecker::new();
        assert!(bc
            .create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .is_ok());
    }

    #[test]
    fn test_double_exclusive_borrow_rejected() {
        let mut bc = BorrowChecker::new();
        bc.create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .unwrap();
        let err = bc.create_borrow(0, 1, BorrowMode::Exclusive, span(), None);
        assert!(err.is_err());
        let msg = format!("{:?}", err.unwrap_err());
        assert!(msg.contains("[B0001]"), "got: {}", msg);
    }

    #[test]
    fn test_multiple_shared_borrows_ok() {
        let mut bc = BorrowChecker::new();
        assert!(bc
            .create_borrow(0, 0, BorrowMode::Shared, span(), None)
            .is_ok());
        assert!(bc
            .create_borrow(0, 1, BorrowMode::Shared, span(), None)
            .is_ok());
        assert!(bc
            .create_borrow(0, 2, BorrowMode::Shared, span(), None)
            .is_ok());
    }

    #[test]
    fn test_exclusive_after_shared_rejected() {
        let mut bc = BorrowChecker::new();
        bc.create_borrow(0, 0, BorrowMode::Shared, span(), None)
            .unwrap();
        let err = bc.create_borrow(0, 1, BorrowMode::Exclusive, span(), None);
        assert!(err.is_err());
        let msg = format!("{:?}", err.unwrap_err());
        assert!(msg.contains("[B0001]"), "got: {}", msg);
    }

    #[test]
    fn test_shared_after_exclusive_rejected() {
        let mut bc = BorrowChecker::new();
        bc.create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .unwrap();
        let err = bc.create_borrow(0, 1, BorrowMode::Shared, span(), None);
        assert!(err.is_err());
        let msg = format!("{:?}", err.unwrap_err());
        assert!(msg.contains("[B0001]"), "got: {}", msg);
    }

    #[test]
    fn test_write_blocked_while_borrowed() {
        let bc_shared = {
            let mut bc = BorrowChecker::new();
            bc.create_borrow(0, 0, BorrowMode::Shared, span(), None)
                .unwrap();
            bc
        };
        let err = bc_shared.check_write_allowed(0, None);
        assert!(err.is_err());
        let msg = format!("{:?}", err.unwrap_err());
        assert!(msg.contains("[B0002]"), "got: {}", msg);
    }

    #[test]
    fn test_write_allowed_when_no_borrows() {
        let bc = BorrowChecker::new();
        assert!(bc.check_write_allowed(0, None).is_ok());
    }

    #[test]
    fn test_read_blocked_only_by_exclusive_borrow() {
        let mut bc = BorrowChecker::new();
        bc.create_borrow(0, 0, BorrowMode::Shared, span(), None)
            .unwrap();
        // Shared borrows do not prevent direct reads.
        assert!(bc.check_read_allowed(0, None).is_ok());

        bc.create_borrow(1, 1, BorrowMode::Exclusive, span(), None)
            .unwrap();
        let msg = format!("{:?}", bc.check_read_allowed(1, None).unwrap_err());
        assert!(msg.contains("[B0001]"), "got: {}", msg);
    }

    #[test]
    fn test_borrows_released_on_scope_exit() {
        let mut bc = BorrowChecker::new();
        bc.enter_region();
        bc.create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .unwrap();
        assert!(bc.check_write_allowed(0, None).is_err());

        bc.exit_region();

        assert!(bc.check_write_allowed(0, None).is_ok());
        assert!(bc
            .create_borrow(0, 1, BorrowMode::Exclusive, span(), None)
            .is_ok());
    }

    #[test]
    fn test_nested_scopes() {
        let mut bc = BorrowChecker::new();
        bc.enter_region();
        bc.create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .unwrap();
        bc.enter_region();
        bc.create_borrow(1, 1, BorrowMode::Exclusive, span(), None)
            .unwrap();
        assert!(bc.check_write_allowed(0, None).is_err());

        // Leaving the inner region releases only the inner borrow of slot 1.
        bc.exit_region();
        assert!(bc.check_write_allowed(1, None).is_ok());
        assert!(bc.check_write_allowed(0, None).is_err());

        // Leaving the outer region releases slot 0 as well.
        bc.exit_region();
        assert!(bc.check_write_allowed(0, None).is_ok());
    }

    #[test]
    fn test_inner_shared_borrows_released_outer_kept() {
        let mut bc = BorrowChecker::new();
        bc.create_borrow(0, 0, BorrowMode::Shared, span(), None)
            .unwrap();
        bc.enter_region();
        // Several borrows of the same slot in one region must all be released together.
        bc.create_borrow(0, 1, BorrowMode::Shared, span(), None)
            .unwrap();
        bc.create_borrow(0, 2, BorrowMode::Shared, span(), None)
            .unwrap();
        bc.exit_region();

        // The outer shared borrow is still live.
        assert!(bc.check_write_allowed(0, None).is_err());
        assert!(bc
            .create_borrow(0, 3, BorrowMode::Exclusive, span(), None)
            .is_err());
        assert!(bc.check_no_escape(1, None).is_ok());
        assert!(bc.check_no_escape(2, None).is_ok());
        assert!(bc.check_no_escape(0, None).is_err());
    }

    #[test]
    fn test_exit_root_region_releases_root_borrows_and_stays_usable() {
        let mut bc = BorrowChecker::new();
        bc.create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .unwrap();
        bc.exit_region();
        assert_eq!(bc.current_region(), RegionId(0));
        assert!(bc.check_write_allowed(0, None).is_ok());

        let inner = bc.enter_region();
        assert_eq!(bc.current_region(), inner);
        bc.exit_region();
        assert_eq!(bc.current_region(), RegionId(0));
    }

    #[test]
    fn test_different_slots_independent() {
        let mut bc = BorrowChecker::new();
        bc.create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .unwrap();
        assert!(bc
            .create_borrow(1, 1, BorrowMode::Exclusive, span(), None)
            .is_ok());
        assert!(bc.check_write_allowed(1, None).is_err());
        assert!(bc.check_write_allowed(2, None).is_ok());
    }

    #[test]
    fn test_check_no_escape() {
        let mut bc = BorrowChecker::new();
        bc.create_borrow(0, 5, BorrowMode::Exclusive, span(), None)
            .unwrap();
        assert!(bc.check_no_escape(5, None).is_err());
        assert!(bc.check_no_escape(99, None).is_ok());
    }

    #[test]
    fn test_check_no_escape_ok_after_region_exit() {
        let mut bc = BorrowChecker::new();
        bc.enter_region();
        bc.create_borrow(0, 5, BorrowMode::Shared, span(), None)
            .unwrap();
        assert!(bc.check_no_escape(5, None).is_err());
        bc.exit_region();
        assert!(bc.check_no_escape(5, None).is_ok());
    }

    #[test]
    fn test_reset_clears_all_state() {
        let mut bc = BorrowChecker::new();
        bc.enter_region();
        bc.create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .unwrap();
        bc.reset();
        assert!(bc.check_write_allowed(0, None).is_ok());
        assert!(bc
            .create_borrow(0, 0, BorrowMode::Exclusive, span(), None)
            .is_ok());
    }

    #[test]
    fn test_region_ids_are_unique() {
        let mut bc = BorrowChecker::new();
        let r1 = bc.enter_region();
        let r2 = bc.enter_region();
        assert_ne!(r1, r2);
        bc.exit_region();
        let r3 = bc.enter_region();
        assert_ne!(r2, r3);
        assert_ne!(r1, r3);
    }

    #[test]
    fn test_error_carries_source_location() {
        let mut bc = BorrowChecker::new();
        let loc = SourceLocation::new(10, 5);
        bc.create_borrow(0, 0, BorrowMode::Exclusive, span(), Some(loc.clone()))
            .unwrap();
        let err = bc.create_borrow(0, 1, BorrowMode::Exclusive, span(), Some(loc));
        match err {
            Err(ShapeError::SemanticError { location, .. }) => {
                let loc = location.expect("error should carry source location");
                assert_eq!(loc.line, 10);
                assert_eq!(loc.column, 5);
            }
            other => panic!("expected SemanticError, got: {:?}", other),
        }
    }
}