wayrs_utils/
shm_alloc.rs

1//! A simple "free list" shared memory allocator
2
3use std::fs::File;
4use std::io;
5use std::os::fd::AsFd;
6use std::sync::atomic::{AtomicU32, Ordering};
7use std::sync::Arc;
8
9use memmap2::MmapMut;
10
11use wayrs_client::global::BindError;
12use wayrs_client::object::Proxy;
13use wayrs_client::protocol::*;
14use wayrs_client::Connection;
15
16/// A simple "free list" shared memory allocator
17#[derive(Debug)]
18pub struct ShmAlloc {
19    state: ShmAllocState,
20}
21
22#[derive(Debug)]
23enum ShmAllocState {
24    Uninit(WlShm),
25    Init(InitShmPool),
26}
27
28#[derive(Debug)]
29struct InitShmPool {
30    pool: WlShmPool,
31    len: usize,
32    file: File,
33    mmap: MmapMut,
34    segments: Vec<Segment>,
35}
36
37#[derive(Debug)]
38struct Segment {
39    offset: usize,
40    len: usize,
41    refcnt: Arc<AtomicU32>,
42    buffer: Option<(WlBuffer, BufferSpec)>,
43}
44
45#[derive(Debug, Clone, Copy, PartialEq, Eq)]
46pub struct BufferSpec {
47    pub width: u32,
48    pub height: u32,
49    pub stride: u32,
50    pub format: wl_shm::Format,
51}
52
53impl BufferSpec {
54    pub fn size(&self) -> usize {
55        self.stride as usize * self.height as usize
56    }
57}
58
59/// A `wl_buffer` with some metadata.
60#[derive(Debug)]
61pub struct Buffer {
62    spec: BufferSpec,
63    wl: WlBuffer,
64    refcnt: Arc<AtomicU32>,
65    wl_shm_pool: WlShmPool,
66    offset: usize,
67}
68
69impl ShmAlloc {
70    /// Bind `wl_shm` and create new [`ShmAlloc`].
71    pub fn bind<D>(conn: &mut Connection<D>) -> Result<Self, BindError> {
72        Ok(Self::new(conn.bind_singleton(1..=2)?))
73    }
74
75    /// Create new [`ShmAlloc`].
76    ///
77    /// This function takes the ownership of `wl_shm` and destroys it when it is no longer used.
78    pub fn new(wl_shm: WlShm) -> Self {
79        Self {
80            state: ShmAllocState::Uninit(wl_shm),
81        }
82    }
83
84    /// Allocate a new buffer.
85    ///
86    /// The underlying memory pool will be resized if needed. Previously released buffers are
87    /// reused whenever possible.
88    ///
89    /// See [`WlShmPool::create_buffer`] for more info.
90    pub fn alloc_buffer<D>(
91        &mut self,
92        conn: &mut Connection<D>,
93        spec: BufferSpec,
94    ) -> io::Result<(Buffer, &mut [u8])> {
95        // Note: `if let` does not work here because borrow checker is to dumb
96        if matches!(&self.state, ShmAllocState::Init(_)) {
97            let ShmAllocState::Init(pool) = &mut self.state else {
98                unreachable!()
99            };
100            return pool.alloc_buffer(conn, spec);
101        }
102
103        let &ShmAllocState::Uninit(wl_shm) = &self.state else {
104            unreachable!()
105        };
106
107        self.state = ShmAllocState::Init(InitShmPool::new(conn, wl_shm, spec.size())?);
108        if wl_shm.version() >= 2 {
109            wl_shm.release(conn);
110        }
111        let ShmAllocState::Init(pool) = &mut self.state else {
112            unreachable!()
113        };
114        pool.alloc_buffer(conn, spec)
115    }
116
117    /// Release all Wayland resources.
118    pub fn destroy<D>(self, conn: &mut Connection<D>) {
119        match self.state {
120            ShmAllocState::Uninit(wl_shm) => {
121                if wl_shm.version() >= 2 {
122                    wl_shm.release(conn);
123                }
124            }
125            ShmAllocState::Init(pool) => {
126                pool.pool.destroy(conn);
127            }
128        }
129    }
130}
131
132impl Buffer {
133    /// Get the underlying `wl_buffer`.
134    ///
135    /// This `wl_buffer` must be attached to exactly one surface, otherwise the memory may be
136    /// leaked or a panic may occur during [`Connection::dispatch_events`].
137    #[must_use = "memory is leaked if wl_buffer is not attached"]
138    pub fn into_wl_buffer(self) -> WlBuffer {
139        let wl = self.wl;
140        std::mem::forget(self);
141        wl
142    }
143
144    /// Create a `wl_buffer` that shares the same spec and underlying memory as `self`.
145    ///
146    /// This `wl_buffer` must be attached to exactly one surface, otherwise the memory may be
147    /// leaked or a panic may occur during [`Connection::dispatch_events`] or `self`'s drop.
148    ///
149    /// This method is usefull if you want to attach the same buffer to a number of surfaces. In
150    /// fact, this is the only correct way to do it unisg this library.
151    #[must_use = "memory is leaked if wl_buffer is not attached"]
152    pub fn duplicate<D>(&self, conn: &mut Connection<D>) -> WlBuffer {
153        self.refcnt.fetch_add(1, Ordering::AcqRel);
154        let refcnt = Arc::clone(&self.refcnt);
155        self.wl_shm_pool.create_buffer_with_cb(
156            conn,
157            self.offset as i32,
158            self.spec.width as i32,
159            self.spec.height as i32,
160            self.spec.stride as i32,
161            self.spec.format,
162            move |ctx| {
163                assert!(refcnt.fetch_sub(1, Ordering::AcqRel) > 0);
164                ctx.proxy.destroy(ctx.conn);
165            },
166        )
167    }
168
169    /// Get the spec of this buffer
170    pub fn spec(&self) -> BufferSpec {
171        self.spec
172    }
173}
174
175impl Drop for Buffer {
176    fn drop(&mut self) {
177        assert!(self.refcnt.fetch_sub(1, Ordering::AcqRel) > 0);
178    }
179}
180
181impl InitShmPool {
182    fn new<D>(conn: &mut Connection<D>, wl_shm: WlShm, size: usize) -> io::Result<InitShmPool> {
183        let file = shmemfdrs2::create_shmem(c"/wayrs_shm_pool")?;
184        file.set_len(size as u64)?;
185        let mmap = unsafe { MmapMut::map_mut(&file)? };
186
187        let fd_dup = file
188            .as_fd()
189            .try_clone_to_owned()
190            .expect("could not duplicate fd");
191
192        let pool = wl_shm.create_pool(conn, fd_dup, size as i32);
193
194        Ok(Self {
195            pool,
196            len: size,
197            file,
198            mmap,
199            segments: vec![Segment {
200                offset: 0,
201                len: size,
202                refcnt: Arc::new(AtomicU32::new(0)),
203                buffer: None,
204            }],
205        })
206    }
207
208    fn alloc_buffer<D>(
209        &mut self,
210        conn: &mut Connection<D>,
211        spec: BufferSpec,
212    ) -> io::Result<(Buffer, &mut [u8])> {
213        let segment_index = self.alloc_segment(conn, spec)?;
214        let segment = &mut self.segments[segment_index];
215
216        let (wl, spec) = *segment.buffer.get_or_insert_with(|| {
217            let seg_refcnt = Arc::clone(&segment.refcnt);
218            let wl = self.pool.create_buffer_with_cb(
219                conn,
220                segment.offset as i32,
221                spec.width as i32,
222                spec.height as i32,
223                spec.stride as i32,
224                spec.format,
225                move |_| {
226                    assert!(seg_refcnt.fetch_sub(1, Ordering::SeqCst) > 0);
227                    // We don't destroy the buffer here because it can be reused later
228                },
229            );
230            (wl, spec)
231        });
232
233        Ok((
234            Buffer {
235                spec,
236                wl,
237                refcnt: Arc::clone(&segment.refcnt),
238                wl_shm_pool: self.pool,
239                offset: segment.offset,
240            },
241            &mut self.mmap[segment.offset..][..segment.len],
242        ))
243    }
244
245    fn defragment<D>(&mut self, conn: &mut Connection<D>) {
246        let mut i = 0;
247        while i + 1 < self.segments.len() {
248            // `refcnt` cannot go from zero to anything else as it implies that the segment is not
249            // used anymore.
250            if self.segments[i].refcnt.load(Ordering::SeqCst) != 0
251                || self.segments[i + 1].refcnt.load(Ordering::SeqCst) != 0
252            {
253                i += 1;
254                continue;
255            }
256
257            if let Some(buffer) = self.segments[i].buffer.take() {
258                buffer.0.destroy(conn);
259            }
260            if let Some(buffer) = self.segments[i + 1].buffer.take() {
261                buffer.0.destroy(conn);
262            }
263
264            self.segments[i].len += self.segments[i + 1].len;
265
266            self.segments.remove(i + 1);
267        }
268    }
269
270    /// Resize the memmap, at least doubling the size.
271    fn resize<D>(&mut self, conn: &mut Connection<D>, new_len: usize) -> io::Result<()> {
272        if new_len > self.len {
273            self.len = usize::max(self.len * 2, new_len);
274            self.file.set_len(self.len as u64)?;
275            self.pool.resize(conn, self.len as i32);
276            self.mmap = unsafe { MmapMut::map_mut(&self.file)? };
277        }
278        Ok(())
279    }
280
281    /// Returns segment index, does not resize
282    fn try_alloc_in_place<D>(
283        &mut self,
284        conn: &mut Connection<D>,
285        len: usize,
286        spec: BufferSpec,
287    ) -> Option<usize> {
288        fn take_if_free(s: &Segment) -> bool {
289            s.refcnt
290                .compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire)
291                .is_ok()
292        }
293
294        // Find a segment with exact size
295        if let Some((i, segment)) = self
296            .segments
297            .iter_mut()
298            .enumerate()
299            .filter(|(_, s)| s.len == len)
300            .find(|(_, s)| take_if_free(s))
301        {
302            if let Some(buffer) = &segment.buffer {
303                if buffer.1 != spec {
304                    buffer.0.destroy(conn);
305                    segment.buffer = None;
306                }
307            }
308            return Some(i);
309        }
310
311        // Find a segment large enough
312        if let Some((i, segment)) = self
313            .segments
314            .iter_mut()
315            .enumerate()
316            .filter(|(_, s)| s.len > len)
317            .find(|(_, s)| take_if_free(s))
318        {
319            if let Some(buffer) = segment.buffer.take() {
320                buffer.0.destroy(conn);
321            }
322            let extra = segment.len - len;
323            let offset = segment.offset + len;
324            segment.len = len;
325            self.segments.insert(
326                i + 1,
327                Segment {
328                    offset,
329                    len: extra,
330                    refcnt: Arc::new(AtomicU32::new(0)),
331                    buffer: None,
332                },
333            );
334            return Some(i);
335        }
336
337        None
338    }
339
340    // Returns segment index
341    fn alloc_segment<D>(
342        &mut self,
343        conn: &mut Connection<D>,
344        spec: BufferSpec,
345    ) -> io::Result<usize> {
346        let len = spec.size();
347
348        if let Some(index) = self.try_alloc_in_place(conn, len, spec) {
349            return Ok(index);
350        }
351
352        self.defragment(conn);
353        if let Some(index) = self.try_alloc_in_place(conn, len, spec) {
354            return Ok(index);
355        }
356
357        let segments_len = match self.segments.last_mut() {
358            Some(segment)
359                if segment
360                    .refcnt
361                    .compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire)
362                    .is_ok() =>
363            {
364                if let Some(buffer) = segment.buffer.take() {
365                    buffer.0.destroy(conn);
366                }
367                segment.len = len;
368                let new_size = segment.offset + segment.len;
369                self.resize(conn, new_size)?;
370                new_size
371            }
372            _ => {
373                let offset = self.len;
374                self.resize(conn, self.len + len)?;
375                self.segments.push(Segment {
376                    offset,
377                    len,
378                    refcnt: Arc::new(AtomicU32::new(1)),
379                    buffer: None,
380                });
381                offset + len
382            }
383        };
384
385        // Create a segment if `self.resize()` over allocated
386        if segments_len > self.len {
387            self.segments.push(Segment {
388                offset: segments_len,
389                len: self.len - segments_len,
390                refcnt: Arc::new(AtomicU32::new(0)),
391                buffer: None,
392            });
393        }
394
395        Ok(self.segments.len() - 1)
396    }
397}