1use std::fs::{File, OpenOptions};
7use std::io;
8use std::os::unix::fs::PermissionsExt;
9use std::os::unix::io::AsRawFd;
10use std::path::{Path, PathBuf};
11
12use crate::Region;
13
/// Controls when the backing file of a segment created by
/// [`MmapRegion::create`] is removed from the filesystem.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FileCleanup {
    /// Keep the file on disk so other processes can attach to it by path;
    /// the creating region unlinks it when it is dropped.
    Manual,
    /// Unlink the file right after it has been mapped; the mapping stays
    /// usable, but nothing new can attach to it by path.
    Auto,
}
24
/// A file-backed memory segment mapped read/write with `MAP_SHARED`.
///
/// Build one with [`MmapRegion::create`] (new segment) or
/// [`MmapRegion::attach`] (existing segment).
pub struct MmapRegion {
    /// Start of the current mapping; replaced whenever the segment is remapped.
    ptr: *mut u8,
    /// Length in bytes of the current mapping.
    len: usize,
    /// Backing file, kept open so the segment can later be grown and its
    /// on-disk length re-read.
    file: File,
    /// Path the segment was created at or attached from.
    path: PathBuf,
    /// Whether dropping this value unlinks `path`.
    owns_file: bool,
}
41
42impl MmapRegion {
43 pub fn create(path: &Path, size: usize, cleanup: FileCleanup) -> io::Result<Self> {
51 if size == 0 {
52 return Err(io::Error::new(
53 io::ErrorKind::InvalidInput,
54 "size must be > 0",
55 ));
56 }
57
58 let file = OpenOptions::new()
60 .read(true)
61 .write(true)
62 .create(true)
63 .truncate(true)
64 .open(path)
65 .map_err(|e| {
66 let msg = std::format!("Failed to create SHM file at {}: {}", path.display(), e);
67 io::Error::new(e.kind(), msg)
68 })?;
69
70 file.set_permissions(std::fs::Permissions::from_mode(0o600))?;
72
73 file.set_len(size as u64)?;
75
76 let ptr = unsafe {
78 libc::mmap(
79 std::ptr::null_mut(),
80 size,
81 libc::PROT_READ | libc::PROT_WRITE,
82 libc::MAP_SHARED,
83 file.as_raw_fd(),
84 0,
85 )
86 };
87
88 if ptr == libc::MAP_FAILED {
89 return Err(io::Error::last_os_error());
90 }
91
92 let path_buf = path.to_path_buf();
93
94 if cleanup == FileCleanup::Auto {
98 std::fs::remove_file(&path_buf)?;
99 }
100
101 Ok(Self {
102 ptr: ptr as *mut u8,
103 len: size,
104 file,
105 path: path_buf,
106 owns_file: cleanup == FileCleanup::Manual,
107 })
108 }
109
110 pub fn attach(path: &Path) -> io::Result<Self> {
117 let file = OpenOptions::new()
119 .read(true)
120 .write(true)
121 .open(path)
122 .map_err(|e| {
123 let msg = std::format!("Failed to open SHM file at {}: {}", path.display(), e);
124 io::Error::new(e.kind(), msg)
125 })?;
126
127 let metadata = file.metadata()?;
129 let size = metadata.len() as usize;
130
131 if size == 0 {
132 return Err(io::Error::new(
133 io::ErrorKind::InvalidData,
134 "segment file is empty",
135 ));
136 }
137
138 let ptr = unsafe {
140 libc::mmap(
141 std::ptr::null_mut(),
142 size,
143 libc::PROT_READ | libc::PROT_WRITE,
144 libc::MAP_SHARED,
145 file.as_raw_fd(),
146 0,
147 )
148 };
149
150 if ptr == libc::MAP_FAILED {
151 return Err(io::Error::last_os_error());
152 }
153
154 Ok(Self {
155 ptr: ptr as *mut u8,
156 len: size,
157 file,
158 path: path.to_path_buf(),
159 owns_file: false, })
161 }
162
163 #[inline]
165 pub fn region(&self) -> Region {
166 unsafe { Region::from_raw(self.ptr, self.len) }
168 }
169
170 #[inline]
172 pub fn len(&self) -> usize {
173 self.len
174 }
175
176 #[inline]
178 pub fn is_empty(&self) -> bool {
179 self.len == 0
180 }
181
182 #[inline]
184 pub fn path(&self) -> &Path {
185 &self.path
186 }
187
188 pub fn take_ownership(&mut self) {
192 self.owns_file = true;
193 }
194
195 pub fn release_ownership(&mut self) {
199 self.owns_file = false;
200 }
201
202 pub fn resize(&mut self, new_size: usize) -> io::Result<()> {
214 if new_size < self.len {
215 return Err(io::Error::new(
216 io::ErrorKind::InvalidInput,
217 "shrinking is not supported",
218 ));
219 }
220 if new_size == self.len {
221 return Ok(()); }
223
224 self.file.set_len(new_size as u64)?;
226
227 let unmap_result = unsafe { libc::munmap(self.ptr as *mut libc::c_void, self.len) };
229 if unmap_result != 0 {
230 return Err(io::Error::last_os_error());
231 }
232
233 let ptr = unsafe {
235 libc::mmap(
236 std::ptr::null_mut(),
237 new_size,
238 libc::PROT_READ | libc::PROT_WRITE,
239 libc::MAP_SHARED,
240 self.file.as_raw_fd(),
241 0,
242 )
243 };
244
245 if ptr == libc::MAP_FAILED {
246 return Err(io::Error::last_os_error());
247 }
248
249 self.ptr = ptr as *mut u8;
250 self.len = new_size;
251 Ok(())
252 }
253
254 pub fn check_and_remap(&mut self) -> io::Result<bool> {
263 let file_size = self.file.metadata()?.len() as usize;
264 if file_size > self.len {
265 self.resize(file_size)?;
266 Ok(true)
267 } else {
268 Ok(false)
269 }
270 }
271}
272
273impl Drop for MmapRegion {
274 fn drop(&mut self) {
275 unsafe {
277 libc::munmap(self.ptr as *mut libc::c_void, self.len);
278 }
279
280 if self.owns_file {
283 let _ = std::fs::remove_file(&self.path);
284 }
285 }
286}
287
// SAFETY: `MmapRegion` owns the file handle and the mapping it created; the raw
// pointer is only rebound or unmapped through `&mut self` (`resize`,
// `check_and_remap`) or `drop`, so moving the value to another thread is sound.
unsafe impl Send for MmapRegion {}
// SAFETY: every `&self` method (`region`, `len`, `is_empty`, `path`) only reads
// `ptr`, `len` or `path`. Access to the mapped bytes goes through `Region`,
// which presumably exposes raw/unsafe accessors and leaves synchronization of
// the shared memory to the caller — TODO confirm against `Region`'s docs.
unsafe impl Sync for MmapRegion {}
292
#[cfg(test)]
mod tests {
    // Integration-style tests that map real files inside a temporary directory.
    use super::*;

    #[test]
    fn test_create_and_attach() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.shm");

        let region1 = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        assert_eq!(region1.len(), 4096);
        assert!(path.exists());

        let data = region1.region();
        unsafe {
            std::ptr::write(data.as_ptr(), 0x42);
            std::ptr::write(data.as_ptr().add(1), 0x43);
        }

        let region2 = MmapRegion::attach(&path).unwrap();
        assert_eq!(region2.len(), 4096);

        let data2 = region2.region();
        unsafe {
            assert_eq!(std::ptr::read(data2.as_ptr()), 0x42);
            assert_eq!(std::ptr::read(data2.as_ptr().add(1)), 0x43);
        }
    }

    #[test]
    fn test_cleanup_on_drop() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cleanup.shm");

        {
            let _region = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();
            assert!(path.exists());
        }

        assert!(!path.exists());
    }

    #[test]
    fn test_auto_cleanup_unlinks_immediately() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("auto.shm");

        let region = MmapRegion::create(&path, 4096, FileCleanup::Auto).unwrap();
        // Auto mode unlinks the file during `create`...
        assert!(!path.exists());
        // ...but the mapping itself stays usable.
        assert_eq!(region.len(), 4096);
        let data = region.region();
        unsafe {
            std::ptr::write(data.as_ptr().add(10), 0x5A);
            assert_eq!(std::ptr::read(data.as_ptr().add(10)), 0x5A);
        }

        // Nothing can attach by path once the file is gone.
        let err = MmapRegion::attach(&path)
            .err()
            .expect("attach to an unlinked segment must fail");
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn test_attached_does_not_cleanup() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("attached.shm");

        let owner = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();

        {
            let _attached = MmapRegion::attach(&path).unwrap();
            assert!(path.exists());
        }

        assert!(path.exists());

        drop(owner);
        assert!(!path.exists());
    }

    #[test]
    fn test_ownership_transfer() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("transfer.shm");

        let mut owner = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();
        let mut guest = MmapRegion::attach(&path).unwrap();
        owner.release_ownership();
        guest.take_ownership();

        drop(owner);
        assert!(path.exists());

        drop(guest);
        assert!(!path.exists());
    }

    #[test]
    fn test_attach_missing_file_reports_path() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("missing.shm");

        let err = MmapRegion::attach(&path)
            .err()
            .expect("attach to a missing file must fail");
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        assert!(err.to_string().contains(&path.display().to_string()));
    }

    #[test]
    fn test_attach_empty_file_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("empty.shm");
        File::create(&path).unwrap();

        let err = MmapRegion::attach(&path)
            .err()
            .expect("attach to an empty file must fail");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_shared_writes() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("shared.shm");

        let region1 = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        let region2 = MmapRegion::attach(&path).unwrap();

        let data2 = region2.region();
        unsafe {
            std::ptr::write(data2.as_ptr().add(100), 0xAB);
        }

        let data1 = region1.region();
        unsafe {
            assert_eq!(std::ptr::read(data1.as_ptr().add(100)), 0xAB);
        }
    }

    #[test]
    fn test_permissions() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("perms.shm");

        let _region = MmapRegion::create(&path, 1024, FileCleanup::Manual).unwrap();

        let metadata = std::fs::metadata(&path).unwrap();
        let mode = metadata.permissions().mode() & 0o777;
        assert_eq!(mode, 0o600);
    }

    #[test]
    fn test_zero_size_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("zero.shm");

        let result = MmapRegion::create(&path, 0, FileCleanup::Manual);
        assert_eq!(
            result.err().map(|e| e.kind()),
            Some(io::ErrorKind::InvalidInput)
        );
        // The size check happens before the file is opened.
        assert!(!path.exists());
    }

    #[test]
    fn test_resize_grows_region() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("resize.shm");

        let mut region = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        assert_eq!(region.len(), 4096);

        unsafe {
            std::ptr::write(region.region().as_ptr(), 0xAB);
        }

        region.resize(8192).unwrap();
        assert_eq!(region.len(), 8192);

        unsafe {
            assert_eq!(std::ptr::read(region.region().as_ptr()), 0xAB);
        }

        unsafe {
            std::ptr::write(region.region().as_ptr().add(5000), 0xCD);
            assert_eq!(std::ptr::read(region.region().as_ptr().add(5000)), 0xCD);
        }
    }

    #[test]
    fn test_resize_shrink_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("shrink.shm");

        let mut region = MmapRegion::create(&path, 8192, FileCleanup::Manual).unwrap();
        let result = region.resize(4096);
        assert_eq!(
            result.err().map(|e| e.kind()),
            Some(io::ErrorKind::InvalidInput)
        );
        // A rejected resize leaves the mapping untouched.
        assert_eq!(region.len(), 8192);
    }

    #[test]
    fn test_check_and_remap() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("remap.shm");

        let mut owner = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();

        let mut guest = MmapRegion::attach(&path).unwrap();
        assert_eq!(guest.len(), 4096);

        owner.resize(8192).unwrap();

        let remapped = guest.check_and_remap().unwrap();
        assert!(remapped);
        assert_eq!(guest.len(), 8192);

        let remapped2 = guest.check_and_remap().unwrap();
        assert!(!remapped2);
    }

    #[test]
    fn test_resize_preserves_shared_data() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("shared_resize.shm");

        let mut owner = MmapRegion::create(&path, 4096, FileCleanup::Manual).unwrap();
        let mut guest = MmapRegion::attach(&path).unwrap();

        unsafe {
            std::ptr::write(owner.region().as_ptr().add(100), 0x42);
        }

        unsafe {
            assert_eq!(std::ptr::read(guest.region().as_ptr().add(100)), 0x42);
        }

        owner.resize(8192).unwrap();

        guest.check_and_remap().unwrap();

        unsafe {
            assert_eq!(std::ptr::read(guest.region().as_ptr().add(100)), 0x42);
        }

        unsafe {
            std::ptr::write(owner.region().as_ptr().add(5000), 0x99);
        }

        unsafe {
            assert_eq!(std::ptr::read(guest.region().as_ptr().add(5000)), 0x99);
        }
    }
}