wapo_env/
args_stack.rs

1use crate::{IntPtr, IntRet, OcallError, Result};
2use log::Level;
3
/// Number of `IntPtr` slots available for carrying the arguments of one ocall.
///
/// Encoded argument stacks are always dumped into an array of exactly this length, and
/// `check_args_length` rejects (at compile time) argument lists that need more slots.
const OCALL_N_ARGS: usize = 4;
5
/// A typed stack of ocall arguments.
///
/// Built from `StackedArgs::empty()` via `push`, `Args` becomes a nested tuple of the form
/// `(top, (next, (..., ())))`, where the most recently pushed element sits on the left.
pub(crate) struct StackedArgs<Args> {
    args: Args,
}
9
10impl StackedArgs<()> {
11    pub(crate) const fn empty() -> Self {
12        StackedArgs { args: () }
13    }
14}
15
16impl<A: Nargs> StackedArgs<A> {
17    pub(crate) fn load(mut raw: &[IntPtr]) -> Option<Self> {
18        Some(check_args_length(StackedArgs {
19            args: Nargs::load(&mut raw)?,
20        }))
21    }
22
23    pub(crate) fn dump(self) -> [IntPtr; OCALL_N_ARGS] {
24        let mut ret = [Default::default(); OCALL_N_ARGS];
25        let data = check_args_length(self).args.dump();
26        ret[..data.len()].copy_from_slice(&data);
27        ret
28    }
29}
30
31impl<A, B> StackedArgs<(A, B)> {
32    fn pop(self) -> (A, StackedArgs<B>) {
33        let (a, args) = self.args;
34        (a, StackedArgs { args })
35    }
36}
37
38impl<B> StackedArgs<B> {
39    fn push<A>(self, arg: A) -> StackedArgs<(A, B)> {
40        StackedArgs {
41            args: (arg, self.args),
42        }
43    }
44}
45
46pub(crate) trait Nargs {
47    const N_ARGS: usize;
48    fn load(buf: &mut &[IntPtr]) -> Option<Self>
49    where
50        Self: Sized;
51
52    // Since #![feature(generic_const_exprs)] is not yet stable, we use OCALL_N_ARGS instead of
53    // Self::N_ARGS
54    fn dump(self) -> [IntPtr; OCALL_N_ARGS];
55}
56
57impl Nargs for () {
58    const N_ARGS: usize = 0;
59    fn load(_buf: &mut &[IntPtr]) -> Option<Self> {
60        Some(())
61    }
62    fn dump(self) -> [IntPtr; OCALL_N_ARGS] {
63        Default::default()
64    }
65}
66
67impl Nargs for IntPtr {
68    const N_ARGS: usize = 1;
69    fn load(buf: &mut &[IntPtr]) -> Option<Self> {
70        let me = *buf.first()?;
71        *buf = &buf[1..];
72        Some(me)
73    }
74
75    fn dump(self) -> [IntPtr; OCALL_N_ARGS] {
76        let mut ret = [0; OCALL_N_ARGS];
77        ret[0] = self;
78        ret
79    }
80}
81
82impl<A, B> Nargs for (A, B)
83where
84    A: Nargs,
85    B: Nargs,
86{
87    const N_ARGS: usize = A::N_ARGS + B::N_ARGS;
88
89    fn load(buf: &mut &[IntPtr]) -> Option<Self> {
90        let b = B::load(buf)?;
91        let a = A::load(buf)?;
92        Some((a, b))
93    }
94
95    fn dump(self) -> [IntPtr; OCALL_N_ARGS] {
96        let (a, b) = self;
97        let mut buf = [IntPtr::default(); OCALL_N_ARGS];
98        buf[0..B::N_ARGS].copy_from_slice(&b.dump()[0..B::N_ARGS]);
99        buf[B::N_ARGS..Self::N_ARGS].copy_from_slice(&a.dump()[..A::N_ARGS]);
100        buf
101    }
102}
103
104// Since the const evaluation of Rust is not powerful enough yet, we use this trick to statically
105// check the argument types encode output do not exceed the maximum number of arguments.
106pub(crate) trait NotTooManyArgs {
107    const TOO_MANY_ARGUMENTS: ();
108}
109impl<T: Nargs> NotTooManyArgs for T {
110    const TOO_MANY_ARGUMENTS: () = [()][(Self::N_ARGS > OCALL_N_ARGS) as usize];
111}
112
113pub(crate) fn check_args_length<T: Nargs + NotTooManyArgs>(v: StackedArgs<T>) -> StackedArgs<T> {
114    #[allow(clippy::let_unit_value)]
115    let _ = T::TOO_MANY_ARGUMENTS;
116    v
117}
118
119pub(crate) trait I32Convertible {
120    fn to_i32(&self) -> i32;
121    fn from_i32(i: i32) -> Result<Self>
122    where
123        Self: Sized;
124}
125
126pub(crate) trait ArgEncode {
127    type Encoded;
128
129    fn encode_arg<A>(self, stack: StackedArgs<A>) -> StackedArgs<(Self::Encoded, A)>;
130}
131
132/// Trait for types that can be encoded to a return value of a ocall.
133pub trait RetEncode {
134    /// Encode the ocall return value into a IntRet
135    fn encode_ret(self) -> IntRet;
136}
137
138pub(crate) trait RetDecode {
139    fn decode_ret(encoded: IntRet) -> Self
140    where
141        Self: Sized;
142}
143
144impl ArgEncode for &[u8] {
145    type Encoded = (IntPtr, IntPtr);
146
147    fn encode_arg<A>(self, stack: StackedArgs<A>) -> StackedArgs<(Self::Encoded, A)> {
148        let ptr = self.as_ptr() as IntPtr;
149        let len = self.len() as IntPtr;
150        stack.push((len, ptr))
151    }
152}
153
/// Host-side decoding of ocall arguments, resolving pointers against guest memory.
#[cfg(feature = "host")]
mod decode {
    use super::*;
    use wiggle::{GuestMemory, GuestPtr, GuestSlice, GuestSliceMut};

    /// Types that can be decoded from the top of a `StackedArgs`.
    ///
    /// Decoding happens in two phases: `decode_arg` pops the encoded slots and builds an
    /// intermediate `Repr` (which may borrow guest memory for `'d`), and `decode_repr` turns a
    /// `'r` borrow of that `Repr` into the final value.
    pub(crate) trait ArgDecode<'d, 'r> {
        /// Raw slot representation popped from the stack.
        type Encoded;
        /// Intermediate value kept alive while the decoded argument is in use.
        type Repr;

        /// Pops this argument's slots off `stack`, resolving them against `vm`.
        fn decode_arg<R>(
            stack: StackedArgs<(Self::Encoded, R)>,
            vm: &'d dyn GuestMemory,
        ) -> Result<(Arg<'d, 'r, Self>, StackedArgs<R>)>
        where
            Self: Sized;

        /// Produces the argument value from its intermediate representation.
        fn decode_repr(repr: &'r mut Self::Repr) -> Result<Self>
        where
            Self: Sized;
    }

    /// A popped argument that still holds its intermediate representation.
    pub(crate) struct Arg<'d, 'r, T: ArgDecode<'d, 'r>>(<T as ArgDecode<'d, 'r>>::Repr);

    impl<'d, 'r, T: ArgDecode<'d, 'r>> Arg<'d, 'r, T> {
        /// Wraps an already-built intermediate representation.
        pub(crate) fn new(repr: T::Repr) -> Self {
            Self(repr)
        }

        /// Converts the stored representation into the final argument value.
        pub(crate) fn extract(&'r mut self) -> Result<T>
        where
            'd: 'r,
        {
            let repr = &mut self.0;
            T::decode_repr(repr)
        }
    }

    impl<A, B> StackedArgs<(A, B)> {
        /// Pops the top argument and decodes it as `T`.
        pub(crate) fn pop_arg<'d: 'r, 'r, T: ArgDecode<'d, 'r, Encoded = A>>(
            self,
            mem: &'d dyn GuestMemory,
        ) -> Result<(Arg<'d, 'r, T>, StackedArgs<B>)> {
            T::decode_arg(self, mem)
        }
    }

    impl<'d: 'r, 'r: 'a, 'a> ArgDecode<'d, 'r> for &'a [u8] {
        type Encoded = (IntPtr, IntPtr);

        type Repr = GuestSlice<'d, u8>;

        /// Resolves the `(len, ptr)` pair to a guest byte slice.
        ///
        /// Any failure from wiggle, or a `None` slice, is reported as `InvalidParameter`.
        fn decode_arg<R>(
            stack: StackedArgs<(Self::Encoded, R)>,
            mem: &'d dyn GuestMemory,
        ) -> Result<(Arg<'d, 'r, Self>, StackedArgs<R>)>
        where
            Self: Sized,
        {
            let ((len, ptr), rest) = stack.pop();
            let slice: GuestSlice<'_, u8> = GuestPtr::new(mem, (ptr as u32, len as u32))
                .as_slice()
                .map_err(|_| OcallError::InvalidParameter)?
                .ok_or(OcallError::InvalidParameter)?;
            Ok((Arg(slice), rest))
        }

        fn decode_repr(repr: &'r mut Self::Repr) -> Result<Self> {
            Ok(&*repr)
        }
    }

    impl<'d: 'r, 'r: 'a, 'a> ArgDecode<'d, 'r> for &'a str {
        type Encoded = (IntPtr, IntPtr);
        type Repr = GuestSlice<'d, u8>;

        /// Decodes exactly like `&[u8]`; UTF-8 validation is deferred to `decode_repr`.
        fn decode_arg<A>(
            stack: StackedArgs<(Self::Encoded, A)>,
            vm: &'d dyn GuestMemory,
        ) -> Result<(Arg<'d, 'r, Self>, StackedArgs<A>)>
        where
            Self: Sized,
        {
            let (bytes, rest) = <&[u8]>::decode_arg(stack, vm)?;
            Ok((Arg(bytes.0), rest))
        }

        /// Validates the bytes as UTF-8, mapping failure to `InvalidEncoding`.
        fn decode_repr(repr: &'r mut Self::Repr) -> Result<Self> {
            core::str::from_utf8(&*repr).or(Err(OcallError::InvalidEncoding))
        }
    }

    impl<'d: 'r, 'r: 'a, 'a> ArgDecode<'d, 'r> for &'a mut [u8] {
        type Encoded = (IntPtr, IntPtr);
        type Repr = GuestSliceMut<'d, u8>;

        /// Resolves the `(len, ptr)` pair to a mutable guest byte slice.
        ///
        /// Any failure from wiggle, or a `None` slice, is reported as `InvalidParameter`.
        fn decode_arg<A>(
            stack: StackedArgs<(Self::Encoded, A)>,
            mem: &'d dyn GuestMemory,
        ) -> Result<(Arg<'d, 'r, Self>, StackedArgs<A>)>
        where
            Self: Sized,
        {
            let ((len, ptr), rest) = stack.pop();
            let slice: GuestSliceMut<'_, u8> = GuestPtr::new(mem, (ptr as u32, len as u32))
                .as_slice_mut()
                .map_err(|_| OcallError::InvalidParameter)?
                .ok_or(OcallError::InvalidParameter)?;
            Ok((Arg(slice), rest))
        }

        fn decode_repr(repr: &'r mut Self::Repr) -> Result<Self> {
            Ok(&mut *repr)
        }
    }
}
265
266impl ArgEncode for &str {
267    type Encoded = (IntPtr, IntPtr);
268
269    fn encode_arg<A>(self, stack: StackedArgs<A>) -> StackedArgs<(Self::Encoded, A)> {
270        let bytes = self.as_bytes();
271        bytes.encode_arg(stack)
272    }
273}
274
275impl ArgEncode for &mut [u8] {
276    type Encoded = (IntPtr, IntPtr);
277
278    fn encode_arg<A>(self, stack: StackedArgs<A>) -> StackedArgs<(Self::Encoded, A)> {
279        let ptr = self.as_mut_ptr() as IntPtr;
280        let len = self.len() as IntPtr;
281        stack.push((len, ptr))
282    }
283}
284
285impl<B> StackedArgs<B> {
286    pub(crate) fn push_arg<Arg: ArgEncode>(self, v: Arg) -> StackedArgs<(Arg::Encoded, B)> {
287        v.encode_arg(self)
288    }
289}
290
291macro_rules! impl_codec_i {
292    ($typ: ty) => {
293        impl I32Convertible for $typ {
294            fn to_i32(&self) -> i32 {
295                *self as i32
296            }
297            fn from_i32(i: i32) -> Result<Self> {
298                if i > <$typ>::MAX as i32 || i < (-<$typ>::MAX - 1) as i32 {
299                    Err(OcallError::InvalidEncoding)
300                } else {
301                    Ok(i as Self)
302                }
303            }
304        }
305    };
306}
307impl_codec_i!(i8);
308impl_codec_i!(i16);
309
310macro_rules! impl_codec_u {
311    ($typ: ty) => {
312        impl I32Convertible for $typ {
313            fn to_i32(&self) -> i32 {
314                *self as i32
315            }
316            fn from_i32(i: i32) -> Result<Self> {
317                if i as u32 > <$typ>::MAX as u32 {
318                    Err(OcallError::InvalidEncoding)
319                } else {
320                    Ok(i as Self)
321                }
322            }
323        }
324    };
325}
326impl_codec_u!(u8);
327impl_codec_u!(u16);
328
329macro_rules! impl_codec {
330    ($typ: ty) => {
331        impl I32Convertible for $typ {
332            fn to_i32(&self) -> i32 {
333                *self as i32
334            }
335            fn from_i32(i: i32) -> Result<Self> {
336                Ok(i as Self)
337            }
338        }
339    };
340}
341impl_codec!(i32);
342impl_codec!(u32);
343
344macro_rules! impl_codec64 {
345    ($typ: ty) => {
346        impl ArgEncode for $typ {
347            type Encoded = (IntPtr, IntPtr);
348
349            fn encode_arg<R>(self, stack: StackedArgs<R>) -> StackedArgs<(Self::Encoded, R)> {
350                let low = (self & 0xffffffff) as IntPtr;
351                let high = ((self >> 32) & 0xffffffff) as IntPtr;
352                stack.push((low, high))
353            }
354        }
355
356        #[cfg(feature = "host")]
357        impl<'d: 'r, 'r> decode::ArgDecode<'d, 'r> for $typ {
358            type Encoded = (IntPtr, IntPtr);
359            type Repr = Self;
360
361            fn decode_arg<R>(
362                stack: StackedArgs<(Self::Encoded, R)>,
363                _vm: &'d dyn wiggle::GuestMemory,
364            ) -> Result<(decode::Arg<'d, 'r, Self>, StackedArgs<R>)>
365            where
366                Self: Sized,
367            {
368                let ((low, high), stack) = stack.pop();
369                let high = ((high as Self) << 32);
370                let v = high & (low as Self);
371                Ok((decode::Arg::new(v), stack))
372            }
373            fn decode_repr(repr: &'r mut Self::Repr) -> Result<Self> {
374                Ok(*repr)
375            }
376        }
377    };
378}
379
380impl_codec64!(i64);
381impl_codec64!(u64);
382
383impl<I: I32Convertible> ArgEncode for I {
384    type Encoded = IntPtr;
385
386    fn encode_arg<R>(self, stack: StackedArgs<R>) -> StackedArgs<(Self::Encoded, R)> {
387        stack.push(self.to_i32() as _)
388    }
389}
390
/// On the host, `I32Convertible` types are decoded from their single `i32` slot.
#[cfg(feature = "host")]
impl<'d: 'r, 'r, I: I32Convertible> decode::ArgDecode<'d, 'r> for I {
    type Encoded = IntPtr;
    type Repr = i32;

    fn decode_arg<R>(
        stack: StackedArgs<(Self::Encoded, R)>,
        _vm: &'d dyn wiggle::GuestMemory,
    ) -> Result<(decode::Arg<'d, 'r, Self>, StackedArgs<R>)>
    where
        Self: Sized,
    {
        let (slot, rest) = stack.pop();
        // Validation is deferred to `decode_repr`, which calls `from_i32`.
        Ok((decode::Arg::new(slot as i32), rest))
    }

    fn decode_repr(repr: &'r mut Self::Repr) -> Result<Self> {
        I::from_i32(*repr)
    }
}
411
412impl I32Convertible for bool {
413    fn to_i32(&self) -> i32 {
414        *self as i32
415    }
416    fn from_i32(i: i32) -> Result<Self> {
417        match i {
418            0 => Ok(false),
419            1 => Ok(true),
420            _ => Err(OcallError::InvalidEncoding),
421        }
422    }
423}
424
425impl I32Convertible for OcallError {
426    fn to_i32(&self) -> i32 {
427        *self as u8 as i32
428    }
429    fn from_i32(i: i32) -> Result<Self> {
430        let code = u8::from_i32(i)?;
431        OcallError::try_from(code).or(Err(OcallError::InvalidEncoding))
432    }
433}
434
435impl I32Convertible for () {
436    fn to_i32(&self) -> i32 {
437        0
438    }
439    fn from_i32(i: i32) -> Result<()> {
440        if i == 0 {
441            Ok(())
442        } else {
443            Err(OcallError::InvalidEncoding)
444        }
445    }
446}
447
448impl I32Convertible for Level {
449    fn to_i32(&self) -> i32 {
450        match self {
451            Level::Error => 1,
452            Level::Warn => 2,
453            Level::Info => 3,
454            Level::Debug => 4,
455            Level::Trace => 5,
456        }
457    }
458
459    fn from_i32(i: i32) -> Result<Self> {
460        match i {
461            1 => Ok(Level::Error),
462            2 => Ok(Level::Warn),
463            3 => Ok(Level::Info),
464            4 => Ok(Level::Debug),
465            5 => Ok(Level::Trace),
466            _ => Err(OcallError::InvalidEncoding),
467        }
468    }
469}
470
471impl<A, B> RetEncode for Result<A, B>
472where
473    A: I32Convertible,
474    B: I32Convertible,
475{
476    fn encode_ret(self) -> IntRet {
477        let (tp, val) = match self {
478            Ok(v) => (0, v.to_i32()),
479            Err(err) => (1, err.to_i32()),
480        };
481        ((tp as u32 as i64) << 32) | (val as u32 as i64)
482    }
483}
484
485impl<A, B> RetDecode for Result<A, B>
486where
487    A: I32Convertible,
488    B: I32Convertible,
489{
490    fn decode_ret(encoded: IntRet) -> Self {
491        let tp = ((encoded >> 32) & 0xffffffff) as i32;
492        let val = (encoded & 0xffffffff) as i32;
493        if tp == 0 {
494            Ok(A::from_i32(val).expect("Invalid ocall return"))
495        } else {
496            Err(B::from_i32(val).expect("Invalid ocall return"))
497        }
498    }
499}