wit_bindgen/rt/async_support/abi_buffer.rs

use crate::rt::async_support::StreamOps;
use crate::rt::Cleanup;
use std::alloc::Layout;
use std::mem::{self, MaybeUninit};
use std::ptr;
use std::vec::Vec;

/// A helper structure used with a stream to handle the canonical ABI
/// representation of lists and track partial writes.
///
/// This structure is returned whenever a write to a stream completes. It
/// keeps track of the original buffer used to perform a write (`Vec<T>`) and
/// additionally tracks any partial writes. Writes can then be resumed with
/// this buffer again, or the partial write can be converted back to `Vec<T>`
/// to get access to the remaining values.
///
/// This value is created through the [`StreamWrite`](super::StreamWrite)
/// future's return value.
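///
/// # Example
///
/// A sketch of the intended flow (illustrative only: `new`,
/// `abi_ptr_and_len`, and `advance` are crate-internal, and `ops` stands in
/// for a hypothetical [`StreamOps`] implementation):
///
/// ```ignore
/// let mut buffer = AbiBuffer::new(vec![1u8, 2, 3, 4], ops);
/// // Hand the canonical ABI view of the remaining items to a write.
/// let (ptr, len) = buffer.abi_ptr_and_len();
/// // Suppose the write reports that 3 items were transferred.
/// buffer.advance(3);
/// // Recover the unwritten tail, still in the original backing storage.
/// assert_eq!(buffer.into_vec(), [4]);
/// ```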
pub struct AbiBuffer<O: StreamOps> {
    rust_storage: Vec<MaybeUninit<O::Payload>>,
    ops: O,
    alloc: Option<Cleanup>,
    cursor: usize,
}

impl<O: StreamOps> AbiBuffer<O> {
    pub(crate) fn new(mut vec: Vec<O::Payload>, mut ops: O) -> AbiBuffer<O> {
        // SAFETY: We're converting `Vec<T>` to `Vec<MaybeUninit<T>>`, which
        // should be safe.
        let rust_storage = unsafe {
            let ptr = vec.as_mut_ptr();
            let len = vec.len();
            let cap = vec.capacity();
            mem::forget(vec);
            Vec::<MaybeUninit<O::Payload>>::from_raw_parts(ptr.cast(), len, cap)
        };

        // If the native format differs from the canonical ABI format then all
        // items are lowered at this time.
        //
        // Note that this is probably pretty inefficient for "big" use cases,
        // but it's hoped that "big" use cases are using `u8` and therefore
        // skip this entirely.
        let alloc = if ops.native_abi_matches_canonical_abi() {
            None
        } else {
            let elem_layout = ops.elem_layout();
            let layout = Layout::from_size_align(
                elem_layout.size() * rust_storage.len(),
                elem_layout.align(),
            )
            .unwrap();
            let (mut ptr, cleanup) = Cleanup::new(layout);
            // SAFETY: All items in `rust_storage` are already initialized, so
            // it should be safe to read them and move ownership into the
            // canonical ABI format.
            unsafe {
                for item in rust_storage.iter() {
                    let item = item.assume_init_read();
                    ops.lower(item, ptr);
                    ptr = ptr.add(elem_layout.size());
                }
            }
            cleanup
        };
        AbiBuffer {
            rust_storage,
            alloc,
            ops,
            cursor: 0,
        }
    }

    /// Returns the canonical ABI pointer/length to pass off to a write
    /// operation.
    pub(crate) fn abi_ptr_and_len(&self) -> (*const u8, usize) {
        // If no lowering is needed then `T`'s layout is the same in the
        // canonical ABI, so the original storage can be used as-is. In this
        // situation the list will have been left un-tampered with above.
        if self.ops.native_abi_matches_canonical_abi() {
            // SAFETY: this should be in-bounds, so it should be safe.
            let ptr = unsafe { self.rust_storage.as_ptr().add(self.cursor).cast() };
            let len = self.rust_storage.len() - self.cursor;
            return (ptr, len);
        }

        // Otherwise, when items were lowered, `self.alloc` holds the ABI
        // pointer we should pass along.
        let ptr = self
            .alloc
            .as_ref()
            .map(|c| c.ptr.as_ptr())
            .unwrap_or(ptr::null_mut());
        (
            // SAFETY: this should be in-bounds, so it should be safe.
            unsafe { ptr.add(self.cursor * self.ops.elem_layout().size()) },
            self.rust_storage.len() - self.cursor,
        )
    }

    /// Converts this `AbiBuffer<T>` back into a `Vec<T>`.
    ///
    /// This consumes the buffer and yields back the unwritten values as a
    /// `Vec<T>`. The remaining items in the `Vec<T>` have not yet been
    /// written, and all written items have been removed from the front of the
    /// list.
    ///
    /// Note that the backing storage of the returned `Vec<T>` has not changed
    /// from when this buffer was created.
    ///
    /// Also note that this can be an expensive operation if a partial write
    /// occurred, as it involves shifting items from the end of the vector to
    /// the start of the vector.
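    ///
    /// # Example
    ///
    /// A sketch assuming a no-op [`StreamOps`] such as the `BLANK` vtable in
    /// this module's tests (crate-internal APIs, so illustrative only):
    ///
    /// ```ignore
    /// let mut buffer = AbiBuffer::new(vec![1u8, 2, 3, 4], &BLANK);
    /// buffer.advance(1); // one item was written
    /// assert_eq!(buffer.into_vec(), [2, 3, 4]);
    /// ```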
    pub fn into_vec(mut self) -> Vec<O::Payload> {
        self.take_vec()
    }

    /// Returns the number of items remaining in this buffer.
    pub fn remaining(&self) -> usize {
        self.rust_storage.len() - self.cursor
    }

    /// Advances this buffer by `amt` items.
    ///
    /// This signals that `amt` items are no longer going to be yielded from
    /// `abi_ptr_and_len`. Additionally this will perform any deallocation
    /// necessary for the starting `amt` items in this list.
    pub(crate) fn advance(&mut self, amt: usize) {
        assert!(amt + self.cursor <= self.rust_storage.len());
        if !self.ops.contains_lists() {
            self.cursor += amt;
            return;
        }
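        // Lowered values may own out-of-line allocations for nested lists.
        // Ownership of the first `amt` items has been transferred to the
        // other end of the stream, so free the list allocations they carried.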
        let (mut ptr, len) = self.abi_ptr_and_len();
        assert!(amt <= len);
        for _ in 0..amt {
            // SAFETY: we're managing the pointer passed to `dealloc_lists` and
            // it was initialized with a `lower`, and then the pointer
            // arithmetic should all be in-bounds.
            unsafe {
                self.ops.dealloc_lists(ptr.cast_mut());
                ptr = ptr.add(self.ops.elem_layout().size());
            }
        }
        self.cursor += amt;
    }

    fn take_vec(&mut self) -> Vec<O::Payload> {
        // First, if necessary, convert remaining values within `self.alloc`
        // back into `self.rust_storage`. This is necessary when a lift
        // operation is required, meaning that the representation of `T` is
        // different in the canonical ABI.
        //
        // Note that in that case, when this `AbiBuffer` was created, ownership
        // of all values was moved from the original vector into the `alloc`
        // value. This is the reverse operation, moving all the values back
        // into the vector.
        if !self.ops.native_abi_matches_canonical_abi() {
            let (mut ptr, mut len) = self.abi_ptr_and_len();
            // SAFETY: this should be safe as `lift` is operating on values
            // that were initialized with a previous `lower`, and the pointer
            // arithmetic here should all be in-bounds.
            unsafe {
                for dst in self.rust_storage[self.cursor..].iter_mut() {
                    dst.write(self.ops.lift(ptr.cast_mut()));
                    ptr = ptr.add(self.ops.elem_layout().size());
                    len -= 1;
                }
                assert_eq!(len, 0);
            }
        }

        // Next extract the rust storage and zero out this struct's fields.
        // This is also the location where a "shift" happens to remove items
        // from the beginning of the returned vector, as those have already
        // been transferred somewhere else.
        let mut storage = mem::take(&mut self.rust_storage);
        storage.drain(..self.cursor);
        self.cursor = 0;
        self.alloc = None;

        // SAFETY: we're casting `Vec<MaybeUninit<T>>` here to `Vec<T>`. The
        // elements were either always initialized (no lowering was performed)
        // or we just re-initialized them above from `self.alloc`.
        unsafe {
            let ptr = storage.as_mut_ptr();
            let len = storage.len();
            let cap = storage.capacity();
            mem::forget(storage);
            Vec::<O::Payload>::from_raw_parts(ptr.cast(), len, cap)
        }
    }
}

impl<O> Drop for AbiBuffer<O>
where
    O: StreamOps,
{
    fn drop(&mut self) {
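        // Dropping the buffer must not leak: `take_vec` lifts any remaining
        // lowered values back into the original vector (releasing the
        // canonical ABI allocation), and the returned vector is then dropped.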
        let _ = self.take_vec();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rt::async_support::StreamVtable;
    use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
    use std::vec;

    extern "C" fn cancel(_: u32) -> u32 {
        todo!()
    }
    extern "C" fn drop(_: u32) {
        todo!()
    }
    extern "C" fn new() -> u64 {
        todo!()
    }
    extern "C" fn start_read(_: u32, _: *mut u8, _: usize) -> u32 {
        todo!()
    }
    extern "C" fn start_write(_: u32, _: *const u8, _: usize) -> u32 {
        todo!()
    }

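    // A vtable for a payload whose native representation already matches the
    // canonical ABI: `u8` with no `lower`/`lift`, so `AbiBuffer` hands out
    // the original `Vec` storage directly.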
    static BLANK: StreamVtable<u8> = StreamVtable {
        cancel_read: cancel,
        cancel_write: cancel,
        drop_readable: drop,
        drop_writable: drop,
        dealloc_lists: None,
        lift: None,
        lower: None,
        layout: unsafe { Layout::from_size_align_unchecked(1, 1) },
        new,
        start_read,
        start_write,
    };

    #[test]
    fn blank_advance_to_end() {
        let mut buffer = AbiBuffer::new(vec![1, 2, 3, 4], &BLANK);
        assert_eq!(buffer.remaining(), 4);
        buffer.advance(1);
        assert_eq!(buffer.remaining(), 3);
        buffer.advance(2);
        assert_eq!(buffer.remaining(), 1);
        buffer.advance(1);
        assert_eq!(buffer.remaining(), 0);
        assert_eq!(buffer.into_vec(), []);
    }

    #[test]
    fn blank_advance_partial() {
        let buffer = AbiBuffer::new(vec![1, 2, 3, 4], &BLANK);
        assert_eq!(buffer.into_vec(), [1, 2, 3, 4]);
        let mut buffer = AbiBuffer::new(vec![1, 2, 3, 4], &BLANK);
        buffer.advance(1);
        assert_eq!(buffer.into_vec(), [2, 3, 4]);
        let mut buffer = AbiBuffer::new(vec![1, 2, 3, 4], &BLANK);
        buffer.advance(1);
        buffer.advance(2);
        assert_eq!(buffer.into_vec(), [4]);
    }

    #[test]
    fn blank_ptr_eq() {
        let buf = vec![1, 2, 3, 4];
        let ptr = buf.as_ptr();
        let mut buffer = AbiBuffer::new(buf, &BLANK);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_eq!(a, ptr);
        assert_eq!(b, 4);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [1, 2, 3, 4]);
        }

        buffer.advance(1);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_eq!(a, ptr.wrapping_add(1));
        assert_eq!(b, 3);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [2, 3, 4]);
        }

        buffer.advance(2);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_eq!(a, ptr.wrapping_add(3));
        assert_eq!(b, 1);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [4]);
        }

        let ret = buffer.into_vec();
        assert_eq!(ret, [4]);
        assert_eq!(ret.as_ptr(), ptr);
    }

    #[derive(PartialEq, Eq, Debug)]
    struct B(u8);

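    // A vtable for a payload whose canonical ABI representation differs from
    // its native one: `lower` adds 1 and `lift` subtracts 1, letting tests
    // observe that conversion actually took place.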
    static OP: StreamVtable<B> = StreamVtable {
        cancel_read: cancel,
        cancel_write: cancel,
        drop_readable: drop,
        drop_writable: drop,
        dealloc_lists: Some(|_ptr| {}),
        lift: Some(|ptr| unsafe { B(*ptr - 1) }),
        lower: Some(|b, ptr| unsafe {
            *ptr = b.0 + 1;
        }),
        layout: unsafe { Layout::from_size_align_unchecked(1, 1) },
        new,
        start_read,
        start_write,
    };

    #[test]
    fn op_advance_to_end() {
        let mut buffer = AbiBuffer::new(vec![B(1), B(2), B(3), B(4)], &OP);
        assert_eq!(buffer.remaining(), 4);
        buffer.advance(1);
        assert_eq!(buffer.remaining(), 3);
        buffer.advance(2);
        assert_eq!(buffer.remaining(), 1);
        buffer.advance(1);
        assert_eq!(buffer.remaining(), 0);
        assert_eq!(buffer.into_vec(), []);
    }

    #[test]
    fn op_advance_partial() {
        let buffer = AbiBuffer::new(vec![B(1), B(2), B(3), B(4)], &OP);
        assert_eq!(buffer.into_vec(), [B(1), B(2), B(3), B(4)]);
        let mut buffer = AbiBuffer::new(vec![B(1), B(2), B(3), B(4)], &OP);
        buffer.advance(1);
        assert_eq!(buffer.into_vec(), [B(2), B(3), B(4)]);
        let mut buffer = AbiBuffer::new(vec![B(1), B(2), B(3), B(4)], &OP);
        buffer.advance(1);
        buffer.advance(2);
        assert_eq!(buffer.into_vec(), [B(4)]);
    }

    #[test]
    fn op_ptrs() {
        let buf = vec![B(1), B(2), B(3), B(4)];
        let ptr = buf.as_ptr().cast::<u8>();
        let mut buffer = AbiBuffer::new(buf, &OP);
        let (a, b) = buffer.abi_ptr_and_len();
        let base = a;
        assert_ne!(a, ptr);
        assert_eq!(b, 4);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [2, 3, 4, 5]);
        }

        buffer.advance(1);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_ne!(a, ptr.wrapping_add(1));
        assert_eq!(a, base.wrapping_add(1));
        assert_eq!(b, 3);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [3, 4, 5]);
        }

        buffer.advance(2);
        let (a, b) = buffer.abi_ptr_and_len();
        assert_ne!(a, ptr.wrapping_add(3));
        assert_eq!(a, base.wrapping_add(3));
        assert_eq!(b, 1);
        unsafe {
            assert_eq!(std::slice::from_raw_parts(a, b), [5]);
        }

        let ret = buffer.into_vec();
        assert_eq!(ret, [B(4)]);
        assert_eq!(ret.as_ptr(), ptr.cast());
    }

    #[test]
    fn dealloc_lists() {
        static DEALLOCS: AtomicUsize = AtomicUsize::new(0);
        static OP: StreamVtable<B> = StreamVtable {
            cancel_read: cancel,
            cancel_write: cancel,
            drop_readable: drop,
            drop_writable: drop,
            dealloc_lists: Some(|ptr| {
                let prev = DEALLOCS.fetch_add(1, Relaxed);
                assert_eq!(unsafe { usize::from(*ptr) }, prev + 1);
            }),
            lift: Some(|ptr| unsafe { B(*ptr) }),
            lower: Some(|b, ptr| unsafe {
                *ptr = b.0;
            }),
            layout: unsafe { Layout::from_size_align_unchecked(1, 1) },
            new,
            start_read,
            start_write,
        };

        assert_eq!(DEALLOCS.load(Relaxed), 0);
        let buf = vec![B(1), B(2), B(3), B(4)];
        let mut buffer = AbiBuffer::new(buf, &OP);
        assert_eq!(DEALLOCS.load(Relaxed), 0);
        buffer.abi_ptr_and_len();
        assert_eq!(DEALLOCS.load(Relaxed), 0);

        buffer.advance(1);
        assert_eq!(DEALLOCS.load(Relaxed), 1);
        buffer.abi_ptr_and_len();
        assert_eq!(DEALLOCS.load(Relaxed), 1);
        buffer.advance(2);
        assert_eq!(DEALLOCS.load(Relaxed), 3);
        buffer.abi_ptr_and_len();
        assert_eq!(DEALLOCS.load(Relaxed), 3);
        buffer.into_vec();
        assert_eq!(DEALLOCS.load(Relaxed), 3);
    }
}