Skip to main content

rust_mqtt/
buffer.rs

1//! Contains the trait the client uses to store slices of memory and basic implementations.
2
3#[cfg(feature = "alloc")]
4pub use alloc::AllocBuffer;
5
6#[cfg(feature = "bump")]
7pub use bump::{BumpBuffer, InsufficientSpace};
8
9use crate::bytes::Bytes;
10
11/// A trait to describe anything that can allocate memory.
12///
13/// Returned memory can be borrowed or owned. Either way, it is bound by the `'a`
14/// lifetime - usually just the lifetime of the underlying buffer.
15///
16/// The client does not store any references to memory returned by this provider.
17pub trait BufferProvider<'a> {
18    /// The type returned from a successful buffer provision.
19    /// Must implement [`AsMut`] so that it can be borrowed mutably right after allocation for
20    /// initialization and [`Into`] for storing as [`Bytes`].
21    type Buffer: AsMut<[u8]> + Into<Bytes<'a>>;
22
23    /// The error type returned from a failed buffer provision.
24    #[cfg(not(feature = "defmt"))]
25    type ProvisionError: core::fmt::Debug;
26    /// The error type returned from a failed buffer provision.
27    #[cfg(feature = "defmt")]
28    type ProvisionError: core::fmt::Debug + defmt::Format;
29
30    /// If successful, returns contiguous memory with a size in bytes of the `len` argument.
31    ///
32    /// # Errors
33    ///
34    /// Returns a value of its associated error type if the buffer provision fails.
35    fn provide_buffer(&mut self, len: usize) -> Result<Self::Buffer, Self::ProvisionError>;
36}
37
#[cfg(feature = "bump")]
mod bump {
    use core::{marker::PhantomData, slice};

    use crate::buffer::BufferProvider;

    /// Error returned when the [`BumpBuffer`]'s underlying buffer does not have enough unallocated space.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    #[cfg_attr(feature = "defmt", derive(defmt::Format))]
    pub struct InsufficientSpace;

    /// Allocates memory from an underlying buffer by bumping up a pointer by the requested length.
    ///
    /// Each slice handed out by [`BufferProvider::provide_buffer`] is disjoint from every slice
    /// handed out before it (since construction or the last [`BumpBuffer::reset`]), so several of
    /// them may be alive at the same time.
    ///
    /// Can be reset when no references to buffer contents exist.
    #[derive(Debug)]
    #[cfg_attr(feature = "defmt", derive(defmt::Format))]
    pub struct BumpBuffer<'a> {
        /// Start of the backing slice passed to [`BumpBuffer::new`]. All provided slices are
        /// derived from this pointer.
        ptr: *mut u8,
        /// Length of the backing slice in bytes.
        len: usize,
        /// Offset of the first unallocated byte. Invariant: `index <= len`.
        index: usize,
        /// Ties this struct to the exclusive borrow of the backing slice for `'a`.
        _phantom_data: PhantomData<&'a mut [u8]>,
    }

    impl<'a> BufferProvider<'a> for BumpBuffer<'a> {
        type Buffer = &'a mut [u8];
        type ProvisionError = InsufficientSpace;

        /// Return the next `len` bytes from the buffer, advancing the internal tracking
        /// index. Returns [`InsufficientSpace`] if there isn't enough room.
        fn provide_buffer(&mut self, len: usize) -> Result<Self::Buffer, Self::ProvisionError> {
            if self.remaining_len() < len {
                return Err(InsufficientSpace);
            }

            let start = self.index;

            // SAFETY: `start == self.index <= self.len`, so the resulting pointer lies inside, or
            // one past the end of, the backing slice. The byte offset therefore cannot overflow.
            // `self.ptr` was obtained from that backing slice in `BumpBuffer::new`, and the slice
            // stays exclusively borrowed for `'a`.
            let ptr = unsafe { self.ptr.add(start) };

            // Cannot overflow: `start + len <= self.len` was checked above.
            self.index += len;

            // SAFETY:
            // - `start + len <= self.len` was checked above, so `ptr..ptr + len` lies entirely
            //   inside the backing slice. That memory is initialized, non-null, trivially aligned
            //   for `u8`, and valid for reads and writes for `'a`.
            // - The range `start..start + len` begins at the end of all earlier reservations, and
            //   `self.index` now points past it. It therefore does not overlap any slice handed out
            //   since construction or the last `reset`. Callers of the unsafe `reset` guarantee that
            //   no slice from before the reset is still alive, so no aliasing `&mut` exists.
            // - The returned lifetime `'a` is exactly the lifetime of the exclusive borrow of the
            //   backing slice taken in `BumpBuffer::new`.
            let slice = unsafe { slice::from_raw_parts_mut(ptr, len) };

            Ok(slice)
        }
    }

    impl<'a> BumpBuffer<'a> {
        /// Creates a new [`BumpBuffer`] with the provided slice as underlying buffer.
        #[must_use]
        pub fn new(slice: &'a mut [u8]) -> Self {
            Self {
                ptr: slice.as_mut_ptr(),
                len: slice.len(),
                index: 0,
                _phantom_data: PhantomData,
            }
        }

        /// Returns the remaining amount of unallocated bytes in the underlying buffer.
        #[inline]
        #[must_use]
        pub fn remaining_len(&self) -> usize {
            self.len - self.index
        }

        /// Invalidates all previous allocations by resetting the [`BumpBuffer`]'s internal tracking index into the underlying
        /// buffer, allowing the underlying buffer to be reallocated down the line. After this, the bump buffer will allocate
        /// starting with the first byte of the backing buffer again.
        ///
        /// # Safety
        ///
        /// This method is safe to call when no references to previously allocated slices or underlying buffer content exist.
        /// The caller must ensure no more such references exist. In the context of the client, this is true when no more values
        /// that have a lifetime tied to the used [`BumpBuffer`] instance exist.
        ///
        /// # Example
        ///
        /// ## Sound
        ///
        /// ```rust,ignore
        /// use rust_mqtt::buffer::BumpBuffer;
        /// use rust_mqtt::client::Client;
        /// use rust_mqtt::client::info::ConnectInfo;
        /// use rust_mqtt::client::options::ConnectOptions;
        /// use tokio::net::TcpStream;
        /// use embedded_io_adapters::tokio_1::FromTokio;
        ///
        /// let mut buffer = [0; 1024];
        /// let mut buffer = BumpBuffer::new(&mut buffer);
        /// let mut client: Client<'_, FromTokio<TcpStream>, _, 0, 1, 0, 0> = Client::new(&mut buffer);
        ///
        /// {
        ///     // client_identifier lives inside buffer's backing buffer, so it prevents a reset call.
        ///     let ConnectInfo { client_identifier, .. } = client.connect(todo!(), &ConnectOptions::new(), None).await.unwrap();
        ///
        /// }   // client_identifier is dropped here, now we can reset the buffer
        ///
        /// // Safety: client_identifier and all other previously returned values living in buffer's backing
        /// // buffer don't exist anymore. No aliasing possible.
        /// unsafe { client.buffer_mut().reset() };
        ///
        /// // The next allocation can happen safely here.
        /// client.poll().await.unwrap();
        /// ```
        ///
        /// ## Unsound
        ///
        /// ```rust,ignore
        /// use rust_mqtt::buffer::BumpBuffer;
        /// use rust_mqtt::client::Client;
        /// use rust_mqtt::client::info::ConnectInfo;
        /// use rust_mqtt::client::options::ConnectOptions;
        /// use tokio::net::TcpStream;
        /// use embedded_io_adapters::tokio_1::FromTokio;
        ///
        /// let mut buffer = [0; 1024];
        /// let mut buffer = BumpBuffer::new(&mut buffer);
        /// let mut client: Client<'_, FromTokio<TcpStream>, _, 0, 1, 0, 0> = Client::new(&mut buffer);
        ///
        /// // client_identifier lives inside buffer's backing buffer, so it prevents a reset call.
        /// let ConnectInfo { client_identifier, .. } = client.connect(todo!(), &ConnectOptions::new(), None).await.unwrap();
        ///
        /// // (No) Safety: client_identifier still lives.
        /// unsafe { client.buffer_mut().reset() };
        ///
        /// // The next allocation can happen here and cause an alias to client_identifier.
        /// client.poll().await.unwrap();
        ///
        /// // client_identifier is still alive. It might have a different or even non-UTF-8 value now, scary...
        /// println!("{:?}", client_identifier);
        /// ```
        #[inline]
        pub unsafe fn reset(&mut self) {
            self.index = 0;
        }
    }

    /// Compile-time check that `BumpBuffer<'a>` is covariant in `'a`.
    ///
    /// This function only compiles if a buffer borrowing for a longer lifetime `'b` can be used
    /// where a shorter lifetime `'a` is expected. It is never called.
    fn _assert_covariant<'a, 'b: 'a>(x: BumpBuffer<'b>) -> BumpBuffer<'a> {
        x
    }

    #[cfg(test)]
    mod unit {
        use tokio_test::{assert_err, assert_ok};

        use super::*;

        #[test]
        fn provide_buffer_and_remaining_len() {
            let mut backing = [0; 10];

            {
                let mut buf = BumpBuffer::new(&mut backing);

                assert_eq!(buf.remaining_len(), 10);

                let s1 = assert_ok!(buf.provide_buffer(4));
                assert_eq!(s1.len(), 4);

                s1.copy_from_slice(&[1, 2, 3, 4]);
                assert_eq!(buf.remaining_len(), 6);

                // take remaining 6 bytes
                let s2 = assert_ok!(buf.provide_buffer(6));
                assert_eq!(s2.len(), 6);

                s2.copy_from_slice(&[5, 6, 7, 8, 9, 10]);
                assert_eq!(buf.remaining_len(), 0);

                assert_eq!(s1, [1, 2, 3, 4]);
                assert_eq!(s2, [5, 6, 7, 8, 9, 10]);

                let err = assert_err!(buf.provide_buffer(1));
                assert_eq!(err, InsufficientSpace);
            }

            assert_eq!(backing, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        }

        #[test]
        fn reset_allows_reuse() {
            let mut backing = [0; 6];

            {
                let mut buf = BumpBuffer::new(&mut backing);

                let s1 = {
                    let s1 = assert_ok!(buf.provide_buffer(3));
                    s1.copy_from_slice(&[11, 12, 13]);

                    s1.as_ptr()
                };

                // reset and take again from start
                unsafe { buf.reset() }
                let s2 = assert_ok!(buf.provide_buffer(3));

                // Comparing the slice contents would require `s1` to still be a live reference,
                // which would violate the contract of `BumpBuffer::reset` (aliasing with `s2`).
                // Only the raw start addresses are compared instead.
                assert_eq!(s1, s2.as_ptr());
            }

            assert_eq!(backing, [11, 12, 13, 0, 0, 0]);
        }
    }
}
254
#[cfg(feature = "alloc")]
mod alloc {
    use alloc::{boxed::Box, vec};
    use core::convert::Infallible;

    use crate::buffer::BufferProvider;

    /// Buffer provider backed by the global allocator.
    ///
    /// Every call to [`BufferProvider::provide_buffer`] makes a fresh, zero-initialized heap
    /// allocation. Provision never returns an error because the error type is [`Infallible`].
    #[derive(Debug)]
    #[cfg_attr(feature = "defmt", derive(defmt::Format))]
    pub struct AllocBuffer;

    impl<'a> BufferProvider<'a> for AllocBuffer {
        type Buffer = Box<[u8]>;
        type ProvisionError = Infallible;

        /// Allocates a zeroed buffer of `len` bytes on the heap.
        fn provide_buffer(&mut self, len: usize) -> Result<Self::Buffer, Self::ProvisionError> {
            Ok(vec![0_u8; len].into_boxed_slice())
        }
    }

    #[cfg(test)]
    mod unit {
        use tokio_test::assert_ok;

        use crate::buffer::{BufferProvider, alloc::AllocBuffer};

        #[test]
        fn provide_buffer() {
            let mut provider = AllocBuffer;

            let provided = assert_ok!(provider.provide_buffer(10));
            assert_eq!(10, provided.len());
        }
    }
}