// spsc_ringbuf_core/shared_pool.rs

1use crate::ringbuf::{Consumer as RingBufConsumer, Producer as RingBufProducer, RingBuf};
2use crate::shared_singleton::SharedSingleton;
3
/// Errors produced by the shared-pool producer/consumer API.
#[derive(Debug)]
pub enum SharedPoolError {
    // No free payload slot available in the pool
    PoolFull,
    // Command (allocation) ring buffer is full
    AllocBufFull,
    // Return ring buffer is full
    ReturnBufFull,
    // Command ring buffer is empty, nothing to pop
    AllocBufEmpty,
    // A payload index is staged but the payload was never handed
    // to the consumer side (write_done not called)
    PayloadNotConsumerOwned,
    // Producer or consumer side has already been split off
    AlreadySplit,
}
13
/// Index of a slot in the payload pool.
///
/// Only values in `0..N` reference a real slot; the sentinel value `N`
/// (or anything larger) means "no payload attached". The inner value is
/// private so user code cannot forge an index.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PoolIndex<const N: usize>(u32);

// Get usize from PoolIndex<N>. Fails with Err(()) for the out-of-range
// sentinel, so callers cannot accidentally use it to address the pool.
impl<const N: usize> TryFrom<PoolIndex<N>> for usize {
    type Error = ();

    fn try_from(value: PoolIndex<N>) -> Result<Self, Self::Error> {
        if value.0 >= N as u32 {
            // Invalid, cannot be referenced
            Err(())
        } else {
            // Ok, can be referenced
            Ok(value.0 as usize)
        }
    }
}

impl<const N: usize> PoolIndex<N> {
    /// True when this index refers to an actual pool slot (`0..N`).
    pub fn is_valid(&self) -> bool {
        self.0 < N as u32
    }
}
37
/// Implemented by command/message types that can carry a reference to a
/// payload slot in the shared pool.
pub trait HasPoolIdx<const N: usize> {
    /// Return the pool index currently stored in this message.
    fn get_pool_idx(&self) -> PoolIndex<N>;
    /// Store `pindex` in this message.
    fn set_pool_idx(&mut self, pindex: PoolIndex<N>);
}
42
/// Producer side of the shared pool: stages commands (optionally with a
/// pooled payload) into the allocation ring buffer and receives freed
/// payload indices back from the consumer via the return ring buffer.
pub struct Producer<'a, T, Q: HasPoolIdx<N>, const N: usize, const M: usize> {
    // Producer handle for the command allocation
    pub alloc_prod: RingBufProducer<'a, Q, M>,
    // Consumer handle for the return ringbuf
    pub return_cons: RingBufConsumer<'a, Q, M>,
    // Reference to the payload pool
    pool_ref: &'a [SharedSingleton<T>; N],
}
51
52impl<'a, T, Q: HasPoolIdx<N>, const N: usize, const M: usize> Producer<'a, T, Q, N, M> {
53    pub const fn new(
54        alloc_prod: RingBufProducer<'a, Q, M>,
55        return_cons: RingBufConsumer<'a, Q, M>,
56        pool_ref: &'a [SharedSingleton<T>; N],
57    ) -> Self {
58        Producer {
59            alloc_prod,
60            return_cons,
61            pool_ref,
62        }
63    }
64
65    // Internal - get an item from the pool
66    fn take_pool_item(&mut self) -> PoolIndex<N> {
67        // Check the return queue
68        if let Some(item) = self.return_cons.reader_front() {
69            // If there's a return item it must be a valid
70            // pool index
71            let payload_idx = usize::try_from(item.get_pool_idx()).unwrap();
72
73            // Assert location indicated as free is actually vacant
74            assert!(self.pool_ref[payload_idx].is_vacant());
75
76            // Pop the return queue
77            assert!(self.return_cons.pop().is_ok());
78
79            return PoolIndex(payload_idx as u32);
80        }
81        // Otherwise nothing is valid
82        PoolIndex(N as u32)
83    }
84
85    // Stage item for write without payload
86    pub fn stage(&mut self) -> Option<&mut Q> {
87        if let Some(item) = self.alloc_prod.writer_front() {
88            item.set_pool_idx(PoolIndex::<N>(N as u32));
89
90            Some(item)
91        } else {
92            None
93        }
94    }
95
96    // Stage a command buffer and an accompanying payload from the pool
97    // Return a pair of mutable references if successful
98    pub fn stage_with_payload(&mut self) -> Result<(&mut Q, &SharedSingleton<T>), SharedPoolError> {
99        if let Ok(idx) = usize::try_from(self.take_pool_item()) {
100            let payload = &self.pool_ref[idx];
101
102            if let Some(item) = self.alloc_prod.writer_front() {
103                item.set_pool_idx(PoolIndex::<N>(idx as u32));
104
105                Ok((item, payload))
106            } else {
107                Err(SharedPoolError::AllocBufFull)
108            }
109        } else {
110            Err(SharedPoolError::PoolFull)
111        }
112    }
113
114    // Commit the command. If command can contain payload, check
115    // if the payload has already been passed to the consumer.
116    pub fn commit(&mut self) -> Result<(), SharedPoolError> {
117        // In payload has been allocated, check if passed to consumer.
118        if let Some(item) = self.alloc_prod.writer_front() {
119            if let Ok(idx) = usize::try_from(item.get_pool_idx()) {
120                if self.pool_ref[idx].try_read().is_none() {
121                    // Payload index is set but not passed to consumer
122                    return Err(SharedPoolError::PayloadNotConsumerOwned);
123                }
124            }
125        }
126        // commit the command queue. Map the only possible commit error (BufFull)
127        // to SharedPoolError::AllocBufFull
128        self.alloc_prod
129            .commit()
130            .map_err(|_| SharedPoolError::AllocBufFull)
131    }
132}
133
/// Consumer side of the shared pool: reads commands (and any attached
/// payload slot) from the allocation ring buffer and returns freed
/// payload indices to the producer via the return ring buffer.
pub struct Consumer<'a, T, Q: HasPoolIdx<N>, const N: usize, const M: usize> {
    // Consumer handle for the command allocation
    pub alloc_cons: RingBufConsumer<'a, Q, M>,
    // Producer handle for the return ringbuf
    pub return_prod: RingBufProducer<'a, Q, M>,
    // Reference to the payload pool
    pool_ref: &'a [SharedSingleton<T>; N],
}
142
143impl<'a, T, Q: HasPoolIdx<N>, const N: usize, const M: usize> Consumer<'a, T, Q, N, M> {
144    pub fn peek_with_payload(&self) -> (Option<&Q>, Option<&SharedSingleton<T>>) {
145        let ret = self.alloc_cons.reader_front();
146
147        match ret {
148            Some(message) => {
149                let has_idx = message.get_pool_idx();
150                if let Ok(idx) = usize::try_from(has_idx) {
151                    (ret, Some(&self.pool_ref[idx]))
152                }
153                else {
154                    (ret, None)
155                }
156            }
157            _ => (None, None)
158        }
159    }
160
161    pub fn peek(&self) -> Option<&Q> {
162        self.alloc_cons.reader_front()
163    }
164
165    pub fn read_pool_item(&self, pidx: PoolIndex<N>) -> Option<&SharedSingleton<T>> {
166        if let Ok(idx) = usize::try_from(pidx) {
167            Some(&self.pool_ref[idx])
168        }
169        else {
170            None
171        }
172    }
173
174    pub fn pop(&mut self) -> Result<(), SharedPoolError> {
175        self.alloc_cons
176            .pop()
177            .map_err(|_| SharedPoolError::AllocBufEmpty)
178    }
179
180    // Return a payload location in the pool back to the Producer
181    pub fn return_payload(&mut self, pidx: PoolIndex<N>) -> Result<(), SharedPoolError> {
182        // Allocation a location in the return queue
183        if let Some(re) = self.return_prod.writer_front() {
184            // Assert returned payload idx is at least valid
185            // That's the best we can do from consumer side
186            assert!(pidx.is_valid());
187
188            // pidx is asserted above to be valid
189            // pidx.0 is private, hence user cannot access the value
190            // directly. Also pool_ref is private
191            assert!(self.pool_ref[pidx.0 as usize].is_vacant());
192
193            re.set_pool_idx(pidx);
194
195            self.return_prod
196                .commit()
197                .map_err(|_| SharedPoolError::ReturnBufFull)
198        } else {
199            Err(SharedPoolError::ReturnBufFull)
200        }
201    }
202}
203
/// Statically allocatable shared pool: a command ring buffer
/// (producer -> consumer), a return ring buffer (consumer -> producer)
/// carrying freed pool indices, and the payload pool itself.
pub struct SharedPool<T, Q: HasPoolIdx<N>, const N: usize, const M: usize> {
    // Commands flowing from producer to consumer
    alloc_rbuf: RingBuf<Q, M>,
    // Freed pool indices flowing back from consumer to producer
    return_rbuf: RingBuf<Q, M>,
    // Payload storage, handed out one slot at a time
    pool: [SharedSingleton<T>; N],
}

// SAFETY: NOTE(review) - soundness appears to rely on the split-once
// protocol (each ring buffer hands out at most one producer and one
// consumer handle) and on SharedSingleton synchronizing the payload
// ownership hand-off; confirm those invariants in ringbuf.rs and
// shared_singleton.rs before relying on cross-thread use.
unsafe impl<T, Q: HasPoolIdx<N>, const N: usize, const M: usize> Sync for SharedPool<T, Q, N, M> {}
211
212impl<T, Q: HasPoolIdx<N>, const N: usize, const M: usize> SharedPool<T, Q, N, M> {
213    // new
214    // initialize return_rbuf to be full
215    // return to be empty
216
217    const OK: () = assert!(M >= N, "Ringbuf capacity (M) must be >= Pool Capacity (N)");
218    
219    #[allow(clippy::let_unit_value)]
220    pub const fn new() -> Self {
221        let _: () = SharedPool::<T, Q, N, M>::OK;
222        SharedPool {
223            alloc_rbuf: RingBuf::new(),
224            return_rbuf: RingBuf::new(),
225            pool: [SharedSingleton::INIT_0; N],
226        }
227    }
228
229    // Return the producer, once in life time
230    pub fn split_prod(&self) -> Result<Producer<'_, T, Q, N, M>, SharedPoolError> {
231        if self.alloc_rbuf.has_split_prod() || self.return_rbuf.has_split_cons() {
232            // Can only split once in life time
233            Err(SharedPoolError::AlreadySplit)
234        } else {
235            // Split the allocation and return ring buffers to their
236            // corresponding producers and consumers. Not expected to fail
237            // since this is already protected by our own has split flag
238            let alloc_p = self.alloc_rbuf.split_prod().unwrap();
239            let ret_c = self.return_rbuf.split_cons().unwrap();
240
241            // Distribute the producers and consumers to the final
242            // Producer and Consumer wrappers
243            let producer = Producer {
244                alloc_prod: alloc_p,
245                return_cons: ret_c,
246                pool_ref: &self.pool,
247            };
248            Ok(producer)
249        }
250    }
251
252    // Return the consumer, once in life time
253    pub fn split_cons(&self) -> Result<Consumer<'_, T, Q, N, M>, SharedPoolError> {
254        if self.alloc_rbuf.has_split_cons() || self.return_rbuf.has_split_prod() {
255            // Can only split once in life time
256            Err(SharedPoolError::AlreadySplit)
257        } else {
258            // Split the allocation and return ring buffers to their
259            // corresponding producers and consumers. Not expected to fail
260            // since this is already protected by our own has split flag
261            let alloc_c = self.alloc_rbuf.split_cons().unwrap();
262            let mut ret_p = self.return_rbuf.split_prod().unwrap();
263
264            // Pre-fill the return queue with all the pool indices
265            for i in 0..N {
266                // Can unwrap here as we don't expect this fail
267                let item = ret_p.writer_front().unwrap();
268                item.set_pool_idx(PoolIndex(i as u32));
269                ret_p.commit().unwrap();
270            }
271
272            let consumer = Consumer {
273                alloc_cons: alloc_c,
274                return_prod: ret_p,
275                pool_ref: &self.pool,
276            };
277            Ok(consumer)
278        }
279    }
280    // Split both producer and consumer handle together
281    pub fn split(&self) -> Result<(Producer<'_, T, Q, N, M>, Consumer<'_, T, Q, N, M>), SharedPoolError> {
282
283        match (self.split_prod(), self.split_cons())  {
284            (Ok(prod), Ok(cons)) => Ok((prod, cons)),
285            _ => Err(SharedPoolError::AlreadySplit)
286        }
287    }
288
289    pub fn num_free(&self) -> u32 {
290        self.return_rbuf.len()
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    const POOL_DEPTH: usize = 16;

    // Command type carrying an id and an optional payload index.
    pub struct Message {
        id: u32,
        payload: PoolIndex<POOL_DEPTH>,
    }

    impl HasPoolIdx<POOL_DEPTH> for Message {
        fn get_pool_idx(&self) -> PoolIndex<POOL_DEPTH> {
            self.payload
        }
        fn set_pool_idx(&mut self, pindex: PoolIndex<POOL_DEPTH>) {
            self.payload = pindex
        }
    }

    // Payload type stored in the pool slots.
    pub struct Payload {
        value: u32,
    }

    // Statically allocated pool shared by the test, built from the INIT_0
    // constants so it can live in a `static`.
    static SHARED_POOL: SharedPool<Payload, Message, 16, 32> = SharedPool {
        alloc_rbuf: RingBuf::INIT_0,
        return_rbuf: RingBuf::INIT_0,
        pool: [SharedSingleton::<Payload>::INIT_0; 16],
    };

    // End-to-end exercise of the stage -> commit -> peek -> return cycle,
    // twice: once returning the payload before popping the command, once
    // popping the command first while keeping payload access.
    #[test]
    fn test_basic() {
        if let Ok((mut producer, mut consumer)) = SHARED_POOL.split() {

            // Allocate the actual command
            let (message, payload) = producer.stage_with_payload().unwrap();

            // Update the message
            message.id = 41;
            let raw = payload.try_write().unwrap();
            raw.value = 42;
            // Pass the payload
            payload.write_done().unwrap();

            // Commit
            assert!(producer.commit().is_ok());

            // Test consumer can see it
            assert!(consumer.peek_with_payload().0.is_some());

            let (recvd, payload) = consumer.peek_with_payload();

            assert!(recvd.unwrap().id == 41);

            assert!(payload.unwrap().try_read().unwrap().value == 42);

            // Return the payload item to producer
            assert!(payload.unwrap().read_done().is_ok());

            // Return the payload location back to the queue
            assert!(consumer.return_payload(recvd.unwrap().get_pool_idx()).is_ok());

            assert!(consumer.pop().is_ok());

            let (message, payload) = producer.stage_with_payload().unwrap();

            // Update the message
            message.id = 43;
            let raw = payload.try_write().unwrap();
            raw.value = 44;
            // Pass the payload
            payload.write_done().unwrap();

            // Commit
            assert!(producer.commit().is_ok());

            // Peek only so we can return the message while holding
            // the payload
            let recvd = consumer.peek().unwrap();

            let payload_idx = recvd.get_pool_idx();

            // Return the message
            assert!(consumer.pop().is_ok());

            // Keep payload access
            let payload = consumer.read_pool_item(payload_idx).unwrap();

            assert!(payload.try_read().unwrap().value == 44);

            assert!(payload.read_done().is_ok());

            assert!(consumer.return_payload(payload_idx).is_ok());

        } else {
            panic!("first split failed!");
        }
    }
}
392
393
394}