1use proptest::prelude::*;
2use std::cell::UnsafeCell;
3use std::marker;
4use wiggle::GuestMemory;
5
6#[derive(Debug, Clone)]
7pub struct MemAreas(Vec<MemArea>);
8impl MemAreas {
9    pub fn new() -> Self {
10        MemAreas(Vec::new())
11    }
12    pub fn insert(&mut self, a: MemArea) {
13        match self.0.binary_search(&a) {
15            Ok(loc) => self.0.insert(loc, a),
17            Err(loc) => self.0.insert(loc, a),
19        }
20    }
21    pub fn iter(&self) -> impl Iterator<Item = &MemArea> {
22        self.0.iter()
23    }
24}
25
26impl<R> From<R> for MemAreas
27where
28    R: AsRef<[MemArea]>,
29{
30    fn from(ms: R) -> MemAreas {
31        let mut out = MemAreas::new();
32        for m in ms.as_ref().into_iter() {
33            out.insert(*m);
34        }
35        out
36    }
37}
38
39impl Into<Vec<MemArea>> for MemAreas {
40    fn into(self) -> Vec<MemArea> {
41        self.0.clone()
42    }
43}
44
45#[repr(align(4096))]
46pub struct HostMemory {
47    buffer: UnsafeCell<[u8; 4096]>,
48}
49impl HostMemory {
50    pub fn new() -> Self {
51        HostMemory {
52            buffer: UnsafeCell::new([0; 4096]),
53        }
54    }
55
56    pub fn mem_area_strat(align: u32) -> BoxedStrategy<MemArea> {
57        prop::num::u32::ANY
58            .prop_filter_map("needs to fit in memory", move |p| {
59                let p_aligned = p - (p % align); let ptr = p_aligned % 4096; if ptr + align < 4096 {
62                    Some(MemArea { ptr, len: align })
63                } else {
64                    None
65                }
66            })
67            .boxed()
68    }
69
70    pub fn invert(regions: &MemAreas) -> MemAreas {
73        let mut out = MemAreas::new();
74        let mut start = 0;
75        for r in regions.iter() {
76            let len = r.ptr - start;
77            if len > 0 {
78                out.insert(MemArea {
79                    ptr: start,
80                    len: r.ptr - start,
81                });
82            }
83            start = r.ptr + r.len;
84        }
85        if start < 4096 {
86            out.insert(MemArea {
87                ptr: start,
88                len: 4096 - start,
89            });
90        }
91        out
92    }
93
94    pub fn byte_slice_strat(size: u32, exclude: &MemAreas) -> BoxedStrategy<MemArea> {
95        let available: Vec<MemArea> = Self::invert(exclude)
96            .iter()
97            .flat_map(|a| a.inside(size))
98            .collect();
99
100        Just(available)
101            .prop_filter("available memory for allocation", |a| !a.is_empty())
102            .prop_flat_map(|a| prop::sample::select(a))
103            .boxed()
104    }
105}
106
107unsafe impl GuestMemory for HostMemory {
108    fn base(&self) -> (*mut u8, u32) {
109        unsafe {
110            let ptr = self.buffer.get();
111            ((*ptr).as_mut_ptr(), (*ptr).len() as u32)
112        }
113    }
114}
115
/// A half-open byte range `[ptr, ptr + len)` within host memory.
///
/// The derived ordering compares `ptr` first and then `len` (field declaration
/// order). `MemAreas` uses this ordering to stay sorted by address.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct MemArea {
    /// Offset of the first byte of the area.
    pub ptr: u32,
    /// Number of bytes in the area.
    pub len: u32,
}
121
122impl MemArea {
123    pub fn overlapping(&self, b: Self) -> bool {
128        let a_range = std::ops::Range {
130            start: self.ptr,
131            end: self.ptr + self.len, };
133        let b_range = std::ops::Range {
135            start: b.ptr,
136            end: b.ptr + b.len,
137        };
138        for b_elem in b_range.clone() {
140            if a_range.contains(&b_elem) {
141                return true;
142            }
143        }
144        for a_elem in a_range {
146            if b_range.contains(&a_elem) {
147                return true;
148            }
149        }
150        return false;
151    }
152    pub fn non_overlapping_set<M>(areas: M) -> bool
153    where
154        M: Into<MemAreas>,
155    {
156        let areas = areas.into();
157        for (aix, a) in areas.iter().enumerate() {
158            for (bix, b) in areas.iter().enumerate() {
159                if aix != bix {
160                    if a.overlapping(*b) {
162                        return false;
163                    }
164                }
165            }
166        }
167        return true;
168    }
169
170    fn inside(&self, len: u32) -> impl Iterator<Item = MemArea> {
172        let end: i64 = self.len as i64 - len as i64;
173        let start = self.ptr;
174        (0..end).into_iter().map(move |v| MemArea {
175            ptr: start + v as u32,
176            len,
177        })
178    }
179}
180
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn hostmemory_is_aligned() {
        // Alignment must hold both on the stack and after moving into a Box.
        let on_stack = HostMemory::new();
        assert_eq!(on_stack.base().0 as usize % 4096, 0);
        let boxed = Box::new(on_stack);
        assert_eq!(boxed.base().0 as usize % 4096, 0);
    }

    #[test]
    fn invert() {
        // Asserts that inverting `input` yields exactly `expected`.
        fn check(input: &[MemArea], expected: &[MemArea]) {
            let regions = MemAreas::from(input);
            let gaps: Vec<MemArea> = HostMemory::invert(&regions).into();
            assert_eq!(expected, &gaps[..]);
        }

        check(&[], &[MemArea { ptr: 0, len: 4096 }]);
        check(
            &[MemArea { ptr: 0, len: 1 }],
            &[MemArea { ptr: 1, len: 4095 }],
        );
        check(
            &[MemArea { ptr: 1, len: 1 }],
            &[MemArea { ptr: 0, len: 1 }, MemArea { ptr: 2, len: 4094 }],
        );
        check(
            &[MemArea { ptr: 1, len: 4095 }],
            &[MemArea { ptr: 0, len: 1 }],
        );
        check(
            &[MemArea { ptr: 0, len: 1 }, MemArea { ptr: 1, len: 4095 }],
            &[],
        );
        check(
            &[MemArea { ptr: 1, len: 2 }, MemArea { ptr: 4, len: 1 }],
            &[
                MemArea { ptr: 0, len: 1 },
                MemArea { ptr: 3, len: 1 },
                MemArea { ptr: 5, len: 4091 },
            ],
        );
    }

    // Strategy yielding three mutually disjoint areas of sizes `s1`, `s2`, `s3`,
    // each chosen from the memory left free by the previous ones.
    fn set_of_slices_strat(
        s1: u32,
        s2: u32,
        s3: u32,
    ) -> BoxedStrategy<(MemArea, MemArea, MemArea)> {
        HostMemory::byte_slice_strat(s1, &MemAreas::new())
            .prop_flat_map(move |first| {
                let taken = MemAreas::from(&[first]);
                (Just(first), HostMemory::byte_slice_strat(s2, &taken))
            })
            .prop_flat_map(move |(first, second)| {
                let taken = MemAreas::from(&[first, second]);
                (
                    Just(first),
                    Just(second),
                    HostMemory::byte_slice_strat(s3, &taken),
                )
            })
            .boxed()
    }

    #[test]
    fn trivial_inside() {
        let outer = MemArea { ptr: 24, len: 4072 };
        let candidates: Vec<MemArea> = outer.inside(24).collect();
        assert!(!candidates.is_empty());
    }

    proptest! {
        #[test]
        fn inside(r in HostMemory::mem_area_strat(123)) {
            // Everything in memory that is not part of `r`.
            let exterior = HostMemory::invert(&MemAreas::from(&[r]));
            for candidate in r.inside(22) {
                // Each candidate lies within `r` ...
                assert!(r.overlapping(candidate));
                assert!(candidate.ptr >= r.ptr);
                assert!(r.ptr + r.len >= candidate.ptr + candidate.len);
                // ... and therefore overlaps nothing outside of it.
                let mut everything = exterior.clone();
                everything.insert(candidate);
                assert!(MemArea::non_overlapping_set(everything));
            }
        }

        #[test]
        fn byte_slices((s1, s2, s3) in set_of_slices_strat(12, 34, 56)) {
            assert!(MemArea::non_overlapping_set(MemAreas::from(&[s1, s2, s3])));
        }
    }
}
290
291use std::cell::RefCell;
292use wiggle::GuestError;
293
294pub struct WasiCtx<'a> {
297    pub guest_errors: RefCell<Vec<GuestError>>,
298    lifetime: marker::PhantomData<&'a ()>,
299}
300
301impl<'a> WasiCtx<'a> {
302    pub fn new() -> Self {
303        Self {
304            guest_errors: RefCell::new(vec![]),
305            lifetime: marker::PhantomData,
306        }
307    }
308}
309
/// Implements `wiggle::GuestErrorType` for the error enum `$errno`, using `WasiCtx`
/// as the context.
///
/// `success` returns `$errno::Ok`. `from_error` prints the error to stderr, records
/// it in `WasiCtx::guest_errors`, and returns `$errno::InvalidArg`. `WasiCtx` must be
/// in scope at the invocation site, and `$errno` must have `Ok` and `InvalidArg`
/// variants.
#[macro_export]
macro_rules! impl_errno {
    ( $errno:ty ) => {
        impl<'a> wiggle::GuestErrorType<'a> for $errno {
            type Context = WasiCtx<'a>;
            fn success() -> $errno {
                <$errno>::Ok
            }
            fn from_error(e: wiggle::GuestError, ctx: &WasiCtx) -> $errno {
                eprintln!("GUEST ERROR: {:?}", e);
                ctx.guest_errors.borrow_mut().push(e);
                // Name the variant through the macro argument, as `success` does,
                // instead of a hard-coded `types::Errno` path that only resolves in
                // some invoking modules.
                <$errno>::InvalidArg
            }
        }
    };
}