wit_bindgen/rt/async_support/
abi_buffer.rs

1use crate::rt::async_support::StreamVtable;
2use crate::rt::Cleanup;
3use std::alloc::Layout;
4use std::mem::{self, MaybeUninit};
5use std::ptr;
6use std::vec::Vec;
7
/// A helper structure used with a stream to handle the canonical ABI
/// representation of lists and track partial writes.
///
/// This structure is returned whenever a write to a stream completes. This
/// keeps track of the original buffer used to perform a write (`Vec<T>`) and
/// additionally tracks any partial writes. Writes can then be resumed with
/// this buffer again or the partial write can be converted back to `Vec<T>` to
/// get access to the remaining values.
///
/// This value is created through the [`StreamWrite`](super::StreamWrite)
/// future's return value.
pub struct AbiBuffer<T: 'static> {
    // Backing storage of the original `Vec<T>` passed to `new`. When
    // `vtable.lower` is `Some`, ownership of every element has been moved
    // into `alloc` (the canonical ABI allocation) and these slots are
    // logically uninitialized until `take_vec` lifts them back.
    rust_storage: Vec<MaybeUninit<T>>,
    // Conversion/layout hooks for `T`'s canonical ABI representation.
    vtable: &'static StreamVtable<T>,
    // Canonical ABI allocation holding lowered items; `None` when no `lower`
    // is provided (native and ABI layouts match) or when nothing was allocated.
    alloc: Option<Cleanup>,
    // Number of items already written (consumed from the front of the buffer).
    cursor: usize,
}
25
26impl<T: 'static> AbiBuffer<T> {
27    pub(crate) fn new(mut vec: Vec<T>, vtable: &'static StreamVtable<T>) -> AbiBuffer<T> {
28        assert_eq!(vtable.lower.is_some(), vtable.lift.is_some());
29
30        // SAFETY: We're converting `Vec<T>` to `Vec<MaybeUninit<T>>`, which
31        // should be safe.
32        let rust_storage = unsafe {
33            let ptr = vec.as_mut_ptr();
34            let len = vec.len();
35            let cap = vec.capacity();
36            mem::forget(vec);
37            Vec::<MaybeUninit<T>>::from_raw_parts(ptr.cast(), len, cap)
38        };
39
40        // If `lower` is provided then the canonical ABI format is different
41        // from the native format, so all items are converted at this time.
42        //
43        // Note that this is probably pretty inefficient for "big" use cases
44        // but it's hoped that "big" use cases are using `u8` and therefore
45        // skip this entirely.
46        let alloc = vtable.lower.and_then(|lower| {
47            let layout = Layout::from_size_align(
48                vtable.layout.size() * rust_storage.len(),
49                vtable.layout.align(),
50            )
51            .unwrap();
52            let (mut ptr, cleanup) = Cleanup::new(layout);
53            let cleanup = cleanup?;
54            // SAFETY: All items in `rust_storage` are already initialized so
55            // it should be safe to read them and move ownership into the
56            // canonical ABI format.
57            unsafe {
58                for item in rust_storage.iter() {
59                    let item = item.assume_init_read();
60                    lower(item, ptr);
61                    ptr = ptr.add(vtable.layout.size());
62                }
63            }
64
65            Some(cleanup)
66        });
67        AbiBuffer {
68            rust_storage,
69            alloc,
70            vtable,
71            cursor: 0,
72        }
73    }
74
75    /// Returns the canonical ABI pointer/length to pass off to a write
76    /// operation.
77    pub(crate) fn abi_ptr_and_len(&self) -> (*const u8, usize) {
78        // If there's no `lower` operation then it means that `T`'s layout is
79        // the same in the canonical ABI so it can be used as-is. In this
80        // situation the list would have been un-tampered with above.
81        if self.vtable.lower.is_none() {
82            // SAFETY: this should be in-bounds, so it should be safe.
83            let ptr = unsafe { self.rust_storage.as_ptr().add(self.cursor).cast() };
84            let len = self.rust_storage.len() - self.cursor;
85            return (ptr, len.try_into().unwrap());
86        }
87
88        // Othereise when `lower` is present that means that `self.alloc` has
89        // the ABI pointer we should pass along.
90        let ptr = self
91            .alloc
92            .as_ref()
93            .map(|c| c.ptr.as_ptr())
94            .unwrap_or(ptr::null_mut());
95        (
96            // SAFETY: this should be in-bounds, so it should be safe.
97            unsafe { ptr.add(self.cursor * self.vtable.layout.size()) },
98            self.rust_storage.len() - self.cursor,
99        )
100    }
101
102    /// Converts this `AbiBuffer<T>` back into a `Vec<T>`
103    ///
104    /// This commit consumes this buffer and yields back unwritten values as a
105    /// `Vec<T>`. The remaining items in `Vec<T>` have not yet been written and
106    /// all written items have been removed from the front of the list.
107    ///
108    /// Note that the backing storage of the returned `Vec<T>` has not changed
109    /// from whe this buffer was created.
110    ///
111    /// Also note that this can be an expensive operation if a partial write
112    /// occurred as this will involve shifting items from the end of the vector
113    /// to the start of the vector.
114    pub fn into_vec(mut self) -> Vec<T> {
115        self.take_vec()
116    }
117
118    /// Returns the number of items remaining in this buffer.
119    pub fn remaining(&self) -> usize {
120        self.rust_storage.len() - self.cursor
121    }
122
123    /// Advances this buffer by `amt` items.
124    ///
125    /// This signals that `amt` items are no longer going to be yielded from
126    /// `abi_ptr_and_len`. Additionally this will perform any deallocation
127    /// necessary for the starting `amt` items in this list.
128    pub(crate) fn advance(&mut self, amt: usize) {
129        assert!(amt + self.cursor <= self.rust_storage.len());
130        let Some(dealloc_lists) = self.vtable.dealloc_lists else {
131            self.cursor += amt;
132            return;
133        };
134        let (mut ptr, len) = self.abi_ptr_and_len();
135        assert!(amt <= len);
136        for _ in 0..amt {
137            // SAFETY: we're managing the pointer passed to `dealloc_lists` and
138            // it was initialized with a `lower`, and then the pointer
139            // arithmetic should all be in-bounds.
140            unsafe {
141                dealloc_lists(ptr.cast_mut());
142                ptr = ptr.add(self.vtable.layout.size());
143            }
144        }
145        self.cursor += amt;
146    }
147
148    fn take_vec(&mut self) -> Vec<T> {
149        // First, if necessary, convert remaining values within `self.alloc`
150        // back into `self.rust_storage`. This is necessary when a lift
151        // operation is available meaning that the representation of `T` is
152        // different in the canonical ABI.
153        //
154        // Note that when `lift` is provided then when this original
155        // `AbiBuffer` was created it moved ownership of all values from the
156        // original vector into the `alloc` value. This is the reverse
157        // operation, moving all the values back into the vector.
158        if let Some(lift) = self.vtable.lift {
159            let (mut ptr, mut len) = self.abi_ptr_and_len();
160            // SAFETY: this should be safe as `lift` is operating on values that
161            // were initialized with a previous `lower`, and the pointer
162            // arithmetic here should all be in-bounds.
163            unsafe {
164                for dst in self.rust_storage[self.cursor..].iter_mut() {
165                    dst.write(lift(ptr.cast_mut()));
166                    ptr = ptr.add(self.vtable.layout.size());
167                    len -= 1;
168                }
169                assert_eq!(len, 0);
170            }
171        }
172
173        // Next extract the rust storage and zero out this struct's fields.
174        // This is also the location where a "shift" happens to remove items
175        // from the beginning of the returned vector as those have already been
176        // transferred somewhere else.
177        let mut storage = mem::take(&mut self.rust_storage);
178        storage.drain(..self.cursor);
179        self.cursor = 0;
180        self.alloc = None;
181
182        // SAFETY: we're casting `Vec<MaybeUninit<T>>` here to `Vec<T>`. The
183        // elements were either always initialized (`lift` is `None`) or we just
184        // re-initialized them above from `self.alloc`.
185        unsafe {
186            let ptr = storage.as_mut_ptr();
187            let len = storage.len();
188            let cap = storage.capacity();
189            mem::forget(storage);
190            Vec::<T>::from_raw_parts(ptr.cast(), len, cap)
191        }
192    }
193}
194
impl<T> Drop for AbiBuffer<T> {
    fn drop(&mut self) {
        // Re-acquire ownership of any remaining items as a `Vec<T>` so their
        // destructors run, and drop the canonical ABI allocation (if any) via
        // `self.alloc`. The returned vector is intentionally discarded.
        let _ = self.take_vec();
    }
}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
    use std::vec;

    // Stub `extern "C"` functions used only to populate `StreamVtable` fields
    // that these tests never invoke; reaching any of them is a test bug.
    extern "C" fn cancel(_: u32) -> u32 {
        todo!()
    }
    extern "C" fn drop(_: u32) {
        todo!()
    }
    extern "C" fn new() -> u64 {
        todo!()
    }
    extern "C" fn start_read(_: u32, _: *mut u8, _: usize) -> u32 {
        todo!()
    }
    extern "C" fn start_write(_: u32, _: *const u8, _: usize) -> u32 {
        todo!()
    }

    // Vtable with no `lift`/`lower`: the Rust representation of `u8` matches
    // the canonical ABI, so `AbiBuffer` should use the original vector's
    // storage directly without any conversion allocation.
    static BLANK: StreamVtable<u8> = StreamVtable {
        cancel_read: cancel,
        cancel_write: cancel,
        drop_readable: drop,
        drop_writable: drop,
        dealloc_lists: None,
        lift: None,
        lower: None,
        layout: unsafe { Layout::from_size_align_unchecked(1, 1) },
        new,
        start_read,
        start_write,
    };

    // `remaining` decreases with each `advance` and reaches zero once the
    // whole buffer has been consumed.
    #[test]
    fn blank_advance_to_end() {
        let mut buffer = AbiBuffer::new(vec![1, 2, 3, 4], &BLANK);
        assert_eq!(buffer.remaining(), 4);
        buffer.advance(1);
        assert_eq!(buffer.remaining(), 3);
        buffer.advance(2);
        assert_eq!(buffer.remaining(), 1);
        buffer.advance(1);
        assert_eq!(buffer.remaining(), 0);
        assert_eq!(buffer.into_vec(), []);
    }

    // `into_vec` yields only the items not yet consumed by `advance`.
    #[test]
    fn blank_advance_partial() {
        let buffer = AbiBuffer::new(vec![1, 2, 3, 4], &BLANK);
        assert_eq!(buffer.into_vec(), [1, 2, 3, 4]);
        let mut buffer = AbiBuffer::new(vec![1, 2, 3, 4], &BLANK);
        buffer.advance(1);
        assert_eq!(buffer.into_vec(), [2, 3, 4]);
        let mut buffer = AbiBuffer::new(vec![1, 2, 3, 4], &BLANK);
        buffer.advance(1);
        buffer.advance(2);
        assert_eq!(buffer.into_vec(), [4]);
    }

    // With no `lower`, `abi_ptr_and_len` must point into the original vector's
    // allocation, offset by the cursor, and `into_vec` must return the very
    // same backing storage.
    #[test]
    fn blank_ptr_eq() {
        let mut buf = vec![1, 2, 3, 4];
        let ptr = buf.as_mut_ptr();
        let mut buffer = AbiBuffer::new(buf, &BLANK);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_eq!(a, ptr);
        assert_eq!(b, 4);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [1, 2, 3, 4]);
        }

        buffer.advance(1);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_eq!(a, ptr.wrapping_add(1));
        assert_eq!(b, 3);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [2, 3, 4]);
        }

        buffer.advance(2);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_eq!(a, ptr.wrapping_add(3));
        assert_eq!(b, 1);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [4]);
        }

        let ret = buffer.into_vec();
        assert_eq!(ret, [4]);
        assert_eq!(ret.as_ptr(), ptr);
    }

    #[derive(PartialEq, Eq, Debug)]
    struct B(u8);

    // Vtable with a non-trivial conversion: `lower` adds 1 to produce the ABI
    // byte and `lift` subtracts 1 to recover the Rust value, making it easy to
    // observe which representation a pointer refers to.
    static OP: StreamVtable<B> = StreamVtable {
        cancel_read: cancel,
        cancel_write: cancel,
        drop_readable: drop,
        drop_writable: drop,
        dealloc_lists: Some(|_ptr| {}),
        lift: Some(|ptr| unsafe { B(*ptr - 1) }),
        lower: Some(|b, ptr| unsafe {
            *ptr = b.0 + 1;
        }),
        layout: unsafe { Layout::from_size_align_unchecked(1, 1) },
        new,
        start_read,
        start_write,
    };

    // Same as `blank_advance_to_end`, but with lift/lower conversions active.
    #[test]
    fn op_advance_to_end() {
        let mut buffer = AbiBuffer::new(vec![B(1), B(2), B(3), B(4)], &OP);
        assert_eq!(buffer.remaining(), 4);
        buffer.advance(1);
        assert_eq!(buffer.remaining(), 3);
        buffer.advance(2);
        assert_eq!(buffer.remaining(), 1);
        buffer.advance(1);
        assert_eq!(buffer.remaining(), 0);
        assert_eq!(buffer.into_vec(), []);
    }

    // Same as `blank_advance_partial`; `into_vec` must lift values back so the
    // returned items equal the originals despite the lowered representation.
    #[test]
    fn op_advance_partial() {
        let buffer = AbiBuffer::new(vec![B(1), B(2), B(3), B(4)], &OP);
        assert_eq!(buffer.into_vec(), [B(1), B(2), B(3), B(4)]);
        let mut buffer = AbiBuffer::new(vec![B(1), B(2), B(3), B(4)], &OP);
        buffer.advance(1);
        assert_eq!(buffer.into_vec(), [B(2), B(3), B(4)]);
        let mut buffer = AbiBuffer::new(vec![B(1), B(2), B(3), B(4)], &OP);
        buffer.advance(1);
        buffer.advance(2);
        assert_eq!(buffer.into_vec(), [B(4)]);
    }

    // With `lower` present, `abi_ptr_and_len` must point into a separate
    // conversion allocation (holding the +1 lowered bytes), not the original
    // vector — yet `into_vec` still returns the original backing storage.
    #[test]
    fn op_ptrs() {
        let mut buf = vec![B(1), B(2), B(3), B(4)];
        let ptr = buf.as_mut_ptr().cast::<u8>();
        let mut buffer = AbiBuffer::new(buf, &OP);
        let (a, b) = buffer.abi_ptr_and_len();
        let base = a;
        assert_ne!(a, ptr);
        assert_eq!(b, 4);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [2, 3, 4, 5]);
        }

        buffer.advance(1);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_ne!(a, ptr.wrapping_add(1));
        assert_eq!(a, base.wrapping_add(1));
        assert_eq!(b, 3);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [3, 4, 5]);
        }

        buffer.advance(2);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_ne!(a, ptr.wrapping_add(3));
        assert_eq!(a, base.wrapping_add(3));
        assert_eq!(b, 1);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [5]);
        }

        let ret = buffer.into_vec();
        assert_eq!(ret, [B(4)]);
        assert_eq!(ret.as_ptr(), ptr.cast());
    }

    // `dealloc_lists` must be called exactly once per advanced item — at
    // `advance` time, not when the buffer is created, inspected, or consumed.
    #[test]
    fn dealloc_lists() {
        static DEALLOCS: AtomicUsize = AtomicUsize::new(0);
        static OP: StreamVtable<B> = StreamVtable {
            cancel_read: cancel,
            cancel_write: cancel,
            drop_readable: drop,
            drop_writable: drop,
            dealloc_lists: Some(|ptr| {
                let prev = DEALLOCS.fetch_add(1, Relaxed);
                assert_eq!(unsafe { usize::from(*ptr) }, prev + 1);
            }),
            lift: Some(|ptr| unsafe { B(*ptr) }),
            lower: Some(|b, ptr| unsafe {
                *ptr = b.0;
            }),
            layout: unsafe { Layout::from_size_align_unchecked(1, 1) },
            new,
            start_read,
            start_write,
        };

        assert_eq!(DEALLOCS.load(Relaxed), 0);
        let buf = vec![B(1), B(2), B(3), B(4)];
        let mut buffer = AbiBuffer::new(buf, &OP);
        assert_eq!(DEALLOCS.load(Relaxed), 0);
        buffer.abi_ptr_and_len();
        assert_eq!(DEALLOCS.load(Relaxed), 0);

        buffer.advance(1);
        assert_eq!(DEALLOCS.load(Relaxed), 1);
        buffer.abi_ptr_and_len();
        assert_eq!(DEALLOCS.load(Relaxed), 1);
        buffer.advance(2);
        assert_eq!(DEALLOCS.load(Relaxed), 3);
        buffer.abi_ptr_and_len();
        assert_eq!(DEALLOCS.load(Relaxed), 3);
        buffer.into_vec();
        assert_eq!(DEALLOCS.load(Relaxed), 3);
    }
}
417}