stak_dynamic/
primitive_set.rs

1use crate::{error::DynamicError, scheme_value::SchemeValue};
2use alloc::{boxed::Box, string::String, vec, vec::Vec};
3use any_fn::AnyFn;
4use bitvec::bitvec;
5use core::any::TypeId;
6use stak_vm::{Cons, Error, Memory, Number, PrimitiveSet, Type, Value};
7use winter_maybe_async::maybe_async;
8
9const MAXIMUM_ARGUMENT_COUNT: usize = 16;
10
11type ArgumentVec<T> = heapless::Vec<T, MAXIMUM_ARGUMENT_COUNT>;
12type SchemeType = (
13    TypeId,
14    Box<dyn Fn(&Memory, Value) -> Result<Option<any_fn::Value>, DynamicError>>,
15    Box<dyn Fn(&mut Memory, any_fn::Value) -> Result<Value, DynamicError>>,
16);
17
18/// A dynamic primitive set equipped with native functions in Rust.
19pub struct DynamicPrimitiveSet<'a, 'b> {
20    functions: &'a mut [(&'a str, AnyFn<'b>)],
21    types: Vec<SchemeType>,
22    values: Vec<Option<any_fn::Value>>,
23}
24
25impl<'a, 'b> DynamicPrimitiveSet<'a, 'b> {
26    /// Creates a primitive set.
27    pub fn new(functions: &'a mut [(&'a str, AnyFn<'b>)]) -> Self {
28        let mut set = Self {
29            functions,
30            types: vec![],
31            values: vec![],
32        };
33
34        // TODO Support more types including `()` and `Vec<u8>`.
35        set.register_type::<bool>();
36        set.register_type::<i8>();
37        set.register_type::<u8>();
38        set.register_type::<i16>();
39        set.register_type::<u16>();
40        set.register_type::<i32>();
41        set.register_type::<u32>();
42        set.register_type::<i64>();
43        set.register_type::<u64>();
44        set.register_type::<f32>();
45        set.register_type::<f64>();
46        set.register_type::<isize>();
47        set.register_type::<usize>();
48        set.register_type::<String>();
49
50        set
51    }
52
53    /// Registers a type compatible between Scheme and Rust.
54    ///
55    /// Values of such types are automatically marshalled when we pass them from
56    /// Scheme to Rust, and vice versa. Marshalling values can lead to the loss
57    /// of information (e.g. floating-point numbers in Scheme marshalled
58    /// into integers in Rust.)
59    pub fn register_type<T: SchemeValue + 'static>(&mut self) {
60        self.types.push((
61            TypeId::of::<T>(),
62            Box::new(|memory, value| Ok(T::from_scheme(memory, value)?.map(any_fn::value))),
63            Box::new(|memory, value| T::into_scheme(value.downcast()?, memory)),
64        ));
65    }
66
67    fn collect_garbages(&mut self, memory: &Memory) -> Result<(), Error> {
68        let mut marks = bitvec![0; self.values.len()];
69
70        for index in 0..(memory.allocation_index() / 2) {
71            let cons = Cons::new((memory.allocation_start() + 2 * index) as _);
72
73            if memory.cdr(cons)?.tag() != Type::Foreign as _ {
74                continue;
75            }
76
77            marks.set(memory.car(cons)?.assume_number().to_i64() as _, true);
78        }
79
80        for (index, mark) in marks.into_iter().enumerate() {
81            if !mark {
82                self.values[index] = None;
83            }
84        }
85
86        Ok(())
87    }
88
89    // TODO Optimize this with `BitSlice::first_zero()`.
90    fn find_free(&self) -> Option<usize> {
91        self.values.iter().position(Option::is_none)
92    }
93
94    fn allocate(&mut self, memory: &Memory) -> Result<usize, Error> {
95        Ok(if let Some(index) = self.find_free() {
96            index
97        } else if let Some(index) = {
98            self.collect_garbages(memory)?;
99            self.find_free()
100        } {
101            index
102        } else {
103            self.values.push(None);
104            self.values.len() - 1
105        })
106    }
107
108    fn convert_from_scheme(
109        &self,
110        memory: &Memory,
111        value: Value,
112        type_id: TypeId,
113    ) -> Result<Option<any_fn::Value>, DynamicError> {
114        for (id, from, _) in &self.types {
115            if type_id == *id {
116                return from(memory, value);
117            }
118        }
119
120        Ok(None)
121    }
122
123    fn convert_into_scheme(
124        &mut self,
125        memory: &mut Memory,
126        value: any_fn::Value,
127    ) -> Result<Value, DynamicError> {
128        for (id, _, into) in &self.types {
129            if value.type_id()? == *id {
130                return into(memory, value);
131            }
132        }
133
134        let index = self.allocate(memory)?;
135
136        self.values[index] = Some(value);
137
138        Ok(memory
139            .allocate(
140                Number::from_i64(index as _).into(),
141                memory.null()?.set_tag(Type::Foreign as _).into(),
142            )?
143            .into())
144    }
145}
146
147impl PrimitiveSet for DynamicPrimitiveSet<'_, '_> {
148    type Error = DynamicError;
149
150    #[maybe_async]
151    fn operate(&mut self, memory: &mut Memory<'_>, primitive: usize) -> Result<(), Self::Error> {
152        if primitive == 0 {
153            memory.set_register(memory.null()?);
154
155            for (name, _) in self.functions.iter().rev() {
156                let list = memory.cons(memory.null()?.into(), memory.register())?;
157                memory.set_register(list);
158                let string = memory.build_raw_string(name)?;
159                memory.set_car(memory.register(), string.into())?;
160            }
161
162            memory.push(memory.register().into())?;
163
164            Ok(())
165        } else {
166            let primitive = primitive - 1;
167            let (_, function) = self
168                .functions
169                .get(primitive)
170                .ok_or(Error::IllegalPrimitive)?;
171
172            let mut arguments = (0..function.arity())
173                .map(|_| memory.pop())
174                .collect::<Result<ArgumentVec<_>, _>>()?;
175            arguments.reverse();
176
177            let cloned_arguments = {
178                arguments
179                    .iter()
180                    .enumerate()
181                    .map(|(index, &value)| {
182                        self.convert_from_scheme(memory, value, function.parameter_types()[index])
183                    })
184                    .collect::<Result<ArgumentVec<_>, _>>()?
185            };
186
187            let mut copied_arguments = ArgumentVec::new();
188
189            for &value in &arguments {
190                let value =
191                    if value.is_cons() && memory.cdr_value(value)?.tag() == Type::Foreign as _ {
192                        Some(
193                            self.values
194                                .get(memory.car_value(value)?.assume_number().to_i64() as usize)
195                                .ok_or(DynamicError::ValueIndex)?
196                                .as_ref()
197                                .ok_or(DynamicError::ValueIndex)?,
198                        )
199                    } else {
200                        None
201                    };
202
203                copied_arguments
204                    .push(value)
205                    .map_err(|_| Error::ArgumentCount)?;
206            }
207
208            let value = self
209                .functions
210                .get_mut(primitive)
211                .ok_or(Error::IllegalPrimitive)?
212                .1
213                .call(
214                    copied_arguments
215                        .into_iter()
216                        .enumerate()
217                        .map(|(index, value)| {
218                            cloned_arguments[index]
219                                .as_ref()
220                                .map_or_else(|| value.ok_or(DynamicError::ForeignValueExpected), Ok)
221                        })
222                        .collect::<Result<ArgumentVec<_>, DynamicError>>()?
223                        .as_slice(),
224                )?;
225
226            let value = self.convert_into_scheme(memory, value)?;
227            memory.push(value)?;
228
229            Ok(())
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use any_fn::{Ref, r#fn, value};
    use winter_maybe_async::maybe_await;

    const HEAP_SIZE: usize = 1 << 8;

    // A type without a registered conversion, which makes it a foreign value.
    struct Foo {
        bar: usize,
    }

    impl Foo {
        const fn new(bar: usize) -> Self {
            Self { bar }
        }

        const fn bar(&self) -> usize {
            self.bar
        }

        fn baz(&mut self, value: usize) {
            self.bar += value;
        }
    }

    #[test]
    fn create() {
        let mut functions = [
            ("make-foo", r#fn(Foo::new)),
            ("foo-bar", r#fn::<(Ref<_>,), _>(Foo::bar)),
            ("foo-baz", r#fn(Foo::baz)),
        ];

        let _ = DynamicPrimitiveSet::new(&mut functions);
    }

    #[test]
    fn allocate_two() {
        let mut heap = [Default::default(); HEAP_SIZE];
        let mut primitive_set = DynamicPrimitiveSet::new(&mut []);
        let mut memory = Memory::new(&mut heap).unwrap();

        // The first allocation takes a fresh slot.
        let first = primitive_set.allocate(&memory).unwrap();
        assert_eq!(first, 0);
        primitive_set.values[first] = Some(value(42usize));
        assert_eq!(primitive_set.find_free(), None);

        // Keep the first slot alive through a foreign cons on the stack.
        let foreign = memory
            .allocate(
                Number::from_i64(first as _).into(),
                memory.null().unwrap().set_tag(Type::Foreign as _).into(),
            )
            .unwrap();
        memory.push(foreign.into()).unwrap();

        // The second allocation must not reuse the first slot.
        let second = primitive_set.allocate(&memory).unwrap();
        assert_eq!(second, 1);
        primitive_set.values[second] = Some(value(42usize));
        assert_eq!(primitive_set.find_free(), None);
    }

    mod garbage_collection {
        use super::*;

        #[test]
        fn collect_none() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut primitive_set = DynamicPrimitiveSet::new(&mut []);
            let memory = Memory::new(&mut heap).unwrap();

            primitive_set.collect_garbages(&memory).unwrap();
        }

        #[tokio::test]
        async fn collect_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
            let mut primitive_set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            // Calling `make-foo` stores a foreign value in the only slot.
            maybe_await!(primitive_set.operate(&mut memory, 1)).unwrap();

            assert_eq!(primitive_set.find_free(), None);

            // Pop a return value from the foreign primitive.
            memory.pop().unwrap();
            memory.collect_garbages(None).unwrap();

            primitive_set.collect_garbages(&memory).unwrap();

            // The slot is free again.
            assert_eq!(primitive_set.find_free(), Some(0));
        }

        #[tokio::test]
        async fn keep_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
            let mut primitive_set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            // Calling `make-foo` stores a foreign value in the only slot.
            maybe_await!(primitive_set.operate(&mut memory, 1)).unwrap();

            assert_eq!(primitive_set.find_free(), None);

            // The return value is still on the stack.
            primitive_set.collect_garbages(&memory).unwrap();

            // The slot stays occupied.
            assert_eq!(primitive_set.find_free(), None);
        }
    }
}