shared_memory_allocator/
lib.rs

1#![feature(allocator_api)]
2#![feature(vec_into_raw_parts)]
3#![warn(clippy::pedantic)]
4
5//! An extremely unsafe experiment in writing a custom allocator to use linux shared memory.
6
7use std::alloc::{AllocError, Layout};
8use std::collections::HashMap;
9use std::io::{Read, Seek, Write};
10use std::mem::MaybeUninit;
11use std::ops::Range;
12use std::ptr::NonNull;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::{Arc, Mutex};
15
/// Suffix appended to the shmid file path to name the file holding the attach counter.
const COUNTER_SUFFIX: &str = "_count";
/// A half-open range of addresses `[start, end)` within the attached shared memory.
type Space = Range<usize>;
18
19// Contains data with infomation regarding state of the shared memory
20struct SharedMemoryDescription {
21    id: String,
22    // Overall address space
23    address_space: Space,
24    // Free memory spaces
25    free: Vec<Space>,
26}
27impl SharedMemoryDescription {
28    fn new(address_space: Space, id: &str) -> Self {
29        let temp = address_space.clone();
30        Self {
31            id: String::from(id),
32            address_space,
33            free: vec![temp],
34        }
35    }
36}
37impl Drop for SharedMemoryDescription {
38    fn drop(&mut self) {
39        // Detach shared memory
40        reset_err();
41        let x = unsafe { libc::shmdt(self.address_space.start as *const libc::c_void) };
42        dbg!(x);
43        check_err();
44
45        // De-crement the count of processes accessing this shared memory
46        let mut file = std::fs::OpenOptions::new()
47            .read(true)
48            .write(true)
49            .open(format!("{}{COUNTER_SUFFIX}", self.id))
50            .unwrap();
51        let mut count = [u8::default()];
52        file.read_exact(&mut count).unwrap();
53        file.seek(std::io::SeekFrom::Start(0)).unwrap();
54        let new_count = count[0] - 1;
55
56        // Only 1 process needs to tell the OS to de-allocate the shared memory
57        if new_count == 0 {
58            let mut shmid_file = std::fs::File::open(&self.id).unwrap();
59            let mut buf = [0; std::mem::size_of::<i32>()];
60            shmid_file.read_exact(&mut buf).unwrap();
61            let shmid = i32::from_ne_bytes(buf);
62
63            // De-allocate shared memory
64            reset_err();
65            let x = unsafe { libc::shmctl(shmid, libc::IPC_RMID, std::ptr::null_mut()) };
66            dbg!(x);
67            check_err();
68
69            // Since the second process closes last this one deletes the file
70            std::fs::remove_file(&self.id).unwrap();
71            std::fs::remove_file(format!("{}{COUNTER_SUFFIX}", self.id)).unwrap();
72        } else {
73            file.write_all(&[new_count]).unwrap();
74        }
75    }
76}
77/// An allocator implementing [`std::alloc::Allocator`] which allocates items in linux shared
78/// memory.
79#[derive(Clone)]
80pub struct SharedAllocator(Arc<Mutex<SharedMemoryDescription>>);
81
82impl SharedAllocator {
83    /// Construct an alloctor storing the shared memory id in a file at `shmid_path`.
84    ///
85    /// Constructing multiple allocators with the same `shmid_path` will use the same shared memory.
86    ///
87    /// After constructing the first allocator of a given `shmid_path` constructing new allocators
88    /// with the same `shmid_path` is the same as cloning the original allocator.
89    ///
90    /// Allocators with the same `shmid_path` across processes access the same memory although are
91    /// unaware of the presence of items created with allocators from other processes.
92    /// If 2 or more processes are allocating items in the same shared memory it is likely memory
93    /// will be corrupted.
94    ///
95    /// When constructing an allocator with the same `shmid_path` as an existing allocator the value
96    /// of `size` will not be used for allocating shared memory but rather attaching shared memory.
97    /// As such if the value shoould not be larger than the shared memory initially allocated (by
98    /// the first allocator constructed with the `shmid_path`)
99    ///
100    /// This library does not currently implement a mechanism for communicating the layout of the
101    /// shared memory between allocators in different processes.
102    /// You have to do this manually, in the example of the simplest use case, 1 process stores a
103    /// large object in shared memory, when this process wishes to handoff to a newer process it
104    /// sends the address of this object in the shared memory over a
105    /// [`std::os::unix::net::UnixDatagram`] to the new process, which can pickup this object more
106    /// quickly than if it had to be serialized and sent over a [`std::os::unix::net::UnixStream`]
107    /// for example
108    ///
109    /// # Panics
110    ///
111    /// For a whole lot of reasons. This is not a production ready library, it is a toy, treat it as
112    /// such.
113    #[must_use]
114    pub fn new(shmid_path: &str, size: usize) -> Self {
115        type MemoryDescriptorMap = HashMap<i32, Arc<Mutex<SharedMemoryDescription>>>;
116        static mut SHARED_MEMORY_DESCRIPTORS: MaybeUninit<Arc<Mutex<MemoryDescriptorMap>>> =
117            MaybeUninit::uninit();
118        static SHARED: AtomicBool = AtomicBool::new(false);
119        let first = !std::path::Path::new(shmid_path).exists();
120        dbg!(first);
121        // If the shared memory id file doesn't exist, this is the first process to use this shared
122        // memory. Thus we must allocate the shared memory.
123        if first {
124            // Allocate shared memory
125            reset_err();
126            let shared_mem_id = unsafe { libc::shmget(libc::IPC_PRIVATE, size, libc::IPC_CREAT) };
127            dbg!(shared_mem_id);
128            check_err();
129            // We simply save the shared memory id to a file for now
130            let mut shmid_file = std::fs::File::create(shmid_path).unwrap();
131            shmid_file.write_all(&shared_mem_id.to_ne_bytes()).unwrap();
132
133            // We create a counter (like a counter in an Arc) to keep the shared memory alive as
134            // long as atleast 1 process is using it.
135            let mut count_file =
136                std::fs::File::create(&format!("{shmid_path}{COUNTER_SUFFIX}")).unwrap();
137            count_file.write_all(&1u8.to_ne_bytes()).unwrap();
138        }
139
140        // Gets shared memory id
141        let mut shmid_file = std::fs::File::open(shmid_path).unwrap();
142        let mut shmid_bytes = [0; 4];
143        shmid_file.read_exact(&mut shmid_bytes).unwrap();
144        // dbg!(shmid_bytes);
145        let shmid = i32::from_ne_bytes(shmid_bytes);
146        dbg!(shmid);
147
148        // If first shared allocator
149        if SHARED.swap(true, Ordering::SeqCst) {
150            unsafe {
151                SHARED_MEMORY_DESCRIPTORS.write(Arc::new(Mutex::new(HashMap::new())));
152            }
153        }
154
155        let map_ref = unsafe { SHARED_MEMORY_DESCRIPTORS.assume_init_mut() };
156        let mut guard = map_ref.lock().unwrap();
157        // If a shared memory description was found, simply create the allocator pointing to this
158        // shared memory.
159        if let Some(shared_memory_description) = guard.get(&shmid) {
160            Self(shared_memory_description.clone())
161        }
162        // If the map of memory descriptions doesn't contain one for this shared memory id this is
163        // the first `SharedAllocator` instance created for this process, and the first time we are
164        // trying to access this shared memory.
165        // Thus here we want to attach the shared memory to this process (creating the shared memory
166        // desription as we do this).
167        else {
168            // Attach shared memory
169            reset_err();
170            let shared_mem_ptr = unsafe { libc::shmat(shmid, std::ptr::null(), 0) };
171            dbg!(shared_mem_ptr);
172            check_err();
173            let addr = shared_mem_ptr as usize;
174            // Create memory desicrption
175            let shared_memory_description = Arc::new(Mutex::new(SharedMemoryDescription::new(
176                addr..addr + size,
177                shmid_path,
178            )));
179            guard.insert(shmid, shared_memory_description.clone());
180            // Return allocator
181            Self(shared_memory_description)
182        }
183    }
184}
185
186unsafe impl std::alloc::Allocator for SharedAllocator {
187    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
188        let mut guard = self.0.lock().unwrap();
189
190        // We find free space large enough
191        let space_opt = guard
192            .free
193            .iter_mut()
194            .find(|space| (space.end - space.start) >= layout.size());
195        let space = match space_opt {
196            Some(x) => x,
197            // In the future when a space cannot be found of this size we should defragment the
198            // address space to produce a large enough contgious space, after this we should attempt
199            // to allocate more shared memory
200            None => unimplemented!(),
201        };
202        // We shrink the space
203        assert!(space.end >= space.start + layout.size());
204        let addr = space.start;
205        space.start += layout.size();
206
207        // We alloc the required memory
208        let ptr = addr as *mut u8;
209        let nonnull_ptr = unsafe {
210            NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(ptr, layout.size()))
211        };
212        Ok(nonnull_ptr)
213    }
214
215    #[allow(clippy::significant_drop_in_scrutinee)]
216    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
217        let mut guard = self.0.lock().unwrap();
218
219        let start = ptr.as_ptr() as usize;
220        let end = start + layout.size();
221
222        if guard.free[0].start >= end {
223            if guard.free[0].end == start {
224                guard.free[0].start = start;
225            } else {
226                guard.free.insert(0, start..end);
227            }
228        }
229        for i in 1..guard.free.len() {
230            if guard.free[i].start >= end {
231                match (guard.free[i - 1].end == start, guard.free[i].start == end) {
232                    (true, true) => {
233                        guard.free[i - 1].end = guard.free[i].end;
234                        guard.free.remove(i);
235                    }
236                    (true, false) => {
237                        guard.free[i - 1].end = end;
238                    }
239                    (false, true) => {
240                        guard.free[i].start = start;
241                    }
242                    (false, false) => {
243                        guard.free.insert(i, start..end);
244                    }
245                }
246            }
247        }
248    }
249}
250
251fn reset_err() {
252    unsafe { *libc::__errno_location() = 0 };
253}
/// Panics if the calling thread's `errno` is non-zero.
///
/// Intended to follow [`reset_err`] and a libc call, turning a failure of that call into a panic.
///
/// # Panics
///
/// If `errno` is non-zero. The panic message contains the error code and the OS description of
/// the error.
fn check_err() {
    // `last_os_error` reads the calling thread's `errno` and knows its textual description.
    let err = std::io::Error::last_os_error();
    let errno = err.raw_os_error().unwrap_or(0);
    if errno != 0 {
        panic!("Error occured, error code: {errno} ({err})");
    }
}
263
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    const PAGE_SIZE: usize = 4096; // 4 KiB
    // The demo uses a single page of shared memory.
    const SIZE: usize = PAGE_SIZE;

    /// Manual multi-process demo rather than a self-checking test.
    ///
    /// When `/tmp/shmid` does not exist yet, this run creates the shared memory and allocates
    /// three vectors in it, printing their raw parts. Every run then sleeps for 20 seconds,
    /// keeping the process alive while another run can be started against the same files.
    #[test]
    fn main() {
        let shmid_file: &str = "/tmp/shmid";

        let first = !std::path::Path::new(shmid_file).exists();
        dbg!(first);

        // Pushing one element at a time makes the vectors grow (and so reallocate) through the
        // shared allocator, which is why `same_item_push` is allowed here.
        #[allow(clippy::same_item_push)]
        if first {
            let shared_allocator = SharedAllocator::new(shmid_file, SIZE);
            let mut x = Vec::<u8, _>::new_in(shared_allocator.clone());
            for _ in 0..10 {
                x.push(7);
            }
            dbg!(x.into_raw_parts());
            let mut y = Vec::<u8, _>::new_in(shared_allocator.clone());
            for _ in 0..20 {
                y.push(69);
            }
            dbg!(y.into_raw_parts());
            let mut z = Vec::<u8, _>::new_in(shared_allocator);
            for _ in 0..5 {
                z.push(220);
            }
            dbg!(z.into_raw_parts());
        }

        std::thread::sleep(Duration::from_secs(20));
    }
}