stak_dynamic/
primitive_set.rs

1use crate::{error::DynamicError, scheme_value::SchemeValue};
2use alloc::{boxed::Box, string::String, vec, vec::Vec};
3use any_fn::AnyFn;
4use bitvec::bitvec;
5use core::any::TypeId;
6use stak_vm::{Cons, Error, Memory, Number, PrimitiveSet, Type, Value};
7use winter_maybe_async::maybe_async;
8
9const MAXIMUM_ARGUMENT_COUNT: usize = 16;
10
11type ArgumentVec<T> = heapless::Vec<T, MAXIMUM_ARGUMENT_COUNT>;
12type SchemeType = (
13    TypeId,
14    Box<dyn Fn(&Memory, Value) -> Option<any_fn::Value>>,
15    Box<dyn Fn(&mut Memory, any_fn::Value) -> Result<Value, DynamicError>>,
16);
17
18/// A dynamic primitive set equipped with native functions in Rust.
19pub struct DynamicPrimitiveSet<'a, 'b> {
20    functions: &'a mut [(&'a str, AnyFn<'b>)],
21    types: Vec<SchemeType>,
22    values: Vec<Option<any_fn::Value>>,
23}
24
25impl<'a, 'b> DynamicPrimitiveSet<'a, 'b> {
26    /// Creates a primitive set.
27    pub fn new(functions: &'a mut [(&'a str, AnyFn<'b>)]) -> Self {
28        let mut set = Self {
29            functions,
30            types: vec![],
31            values: vec![],
32        };
33
34        // TODO Support more types including `()` and `Vec<u8>`.
35        set.register_type::<bool>();
36        set.register_type::<i8>();
37        set.register_type::<u8>();
38        set.register_type::<i16>();
39        set.register_type::<u16>();
40        set.register_type::<i32>();
41        set.register_type::<u32>();
42        set.register_type::<i64>();
43        set.register_type::<u64>();
44        set.register_type::<f32>();
45        set.register_type::<f64>();
46        set.register_type::<isize>();
47        set.register_type::<usize>();
48        set.register_type::<String>();
49
50        set
51    }
52
53    /// Registers a type compatible between Scheme and Rust.
54    ///
55    /// Values of such types are automatically marshalled when we pass them from
56    /// Scheme to Rust, and vice versa. Marshalling values can lead to the loss
57    /// of information (e.g. floating-point numbers in Scheme marshalled
58    /// into integers in Rust.)
59    pub fn register_type<T: SchemeValue + 'static>(&mut self) {
60        self.types.push((
61            TypeId::of::<T>(),
62            Box::new(|memory, value| T::from_scheme(memory, value).map(any_fn::value)),
63            Box::new(|memory, value| T::into_scheme(value.downcast()?, memory)),
64        ));
65    }
66
67    fn collect_garbages(&mut self, memory: &Memory) {
68        let mut marks = bitvec![0; self.values.len()];
69
70        for index in 0..(memory.allocation_index() / 2) {
71            let cons = Cons::new((memory.allocation_start() + 2 * index) as _);
72
73            if memory.cdr(cons).tag() != Type::Foreign as _ {
74                continue;
75            }
76
77            marks.set(memory.car(cons).assume_number().to_i64() as _, true);
78        }
79
80        for (index, mark) in marks.into_iter().enumerate() {
81            if !mark {
82                self.values[index] = None;
83            }
84        }
85    }
86
87    // TODO Optimize this with `BitSlice::first_zero()`.
88    fn find_free(&self) -> Option<usize> {
89        self.values.iter().position(Option::is_none)
90    }
91
92    fn allocate(&mut self, memory: &Memory) -> usize {
93        if let Some(index) = self.find_free() {
94            index
95        } else if let Some(index) = {
96            self.collect_garbages(memory);
97            self.find_free()
98        } {
99            index
100        } else {
101            self.values.push(None);
102            self.values.len() - 1
103        }
104    }
105
106    fn convert_from_scheme(
107        &self,
108        memory: &Memory,
109        value: Value,
110        type_id: TypeId,
111    ) -> Option<any_fn::Value> {
112        for (id, from, _) in &self.types {
113            if type_id == *id {
114                return from(memory, value);
115            }
116        }
117
118        None
119    }
120
121    fn convert_into_scheme(
122        &mut self,
123        memory: &mut Memory,
124        value: any_fn::Value,
125    ) -> Result<Value, DynamicError> {
126        for (id, _, into) in &self.types {
127            if value.type_id()? == *id {
128                return into(memory, value);
129            }
130        }
131
132        let index = self.allocate(memory);
133
134        self.values[index] = Some(value);
135
136        Ok(memory
137            .allocate(
138                Number::from_i64(index as _).into(),
139                memory.null().set_tag(Type::Foreign as _).into(),
140            )?
141            .into())
142    }
143}
144
145impl PrimitiveSet for DynamicPrimitiveSet<'_, '_> {
146    type Error = DynamicError;
147
148    #[maybe_async]
149    fn operate(&mut self, memory: &mut Memory<'_>, primitive: usize) -> Result<(), Self::Error> {
150        if primitive == 0 {
151            memory.set_register(memory.null());
152
153            for (name, _) in self.functions.iter().rev() {
154                let list = memory.cons(memory.null().into(), memory.register())?;
155                memory.set_register(list);
156                let string = memory.build_raw_string(name)?;
157                memory.set_car(memory.register(), string.into());
158            }
159
160            memory.push(memory.register().into())?;
161
162            Ok(())
163        } else {
164            let primitive = primitive - 1;
165            let (_, function) = self
166                .functions
167                .get(primitive)
168                .ok_or(Error::IllegalPrimitive)?;
169
170            let mut arguments = (0..function.arity())
171                .map(|_| memory.pop())
172                .collect::<ArgumentVec<_>>();
173            arguments.reverse();
174
175            let cloned_arguments = {
176                arguments
177                    .iter()
178                    .enumerate()
179                    .map(|(index, &value)| {
180                        self.convert_from_scheme(memory, value, function.parameter_types()[index])
181                    })
182                    .collect::<ArgumentVec<_>>()
183            };
184
185            let mut copied_arguments = ArgumentVec::new();
186
187            for &value in &arguments {
188                let value =
189                    if value.is_cons() && memory.cdr_value(value).tag() == Type::Foreign as _ {
190                        Some(
191                            self.values
192                                .get(memory.car_value(value).assume_number().to_i64() as usize)
193                                .ok_or(DynamicError::ValueIndex)?
194                                .as_ref()
195                                .ok_or(DynamicError::ValueIndex)?,
196                        )
197                    } else {
198                        None
199                    };
200
201                copied_arguments
202                    .push(value)
203                    .map_err(|_| Error::ArgumentCount)?;
204            }
205
206            let value = self
207                .functions
208                .get_mut(primitive)
209                .ok_or(Error::IllegalPrimitive)?
210                .1
211                .call(
212                    copied_arguments
213                        .into_iter()
214                        .enumerate()
215                        .map(|(index, value)| {
216                            cloned_arguments[index]
217                                .as_ref()
218                                .map_or_else(|| value.ok_or(DynamicError::ForeignValueExpected), Ok)
219                        })
220                        .collect::<Result<ArgumentVec<_>, DynamicError>>()?
221                        .as_slice(),
222                )?;
223
224            let value = self.convert_into_scheme(memory, value)?;
225            memory.push(value)?;
226
227            Ok(())
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use any_fn::{Ref, r#fn, value};
    use winter_maybe_async::maybe_await;

    const HEAP_SIZE: usize = 1 << 8;

    /// Minimal native type used to exercise foreign values.
    struct Foo {
        bar: usize,
    }

    impl Foo {
        const fn new(bar: usize) -> Self {
            Self { bar }
        }

        const fn bar(&self) -> usize {
            self.bar
        }

        fn baz(&mut self, value: usize) {
            self.bar += value;
        }
    }

    #[test]
    fn create() {
        let mut functions = [
            ("make-foo", r#fn(Foo::new)),
            ("foo-bar", r#fn::<(Ref<_>,), _>(Foo::bar)),
            ("foo-baz", r#fn(Foo::baz)),
        ];

        DynamicPrimitiveSet::new(&mut functions);
    }

    #[test]
    fn allocate_two() {
        let mut heap = [Default::default(); HEAP_SIZE];
        let mut primitive_set = DynamicPrimitiveSet::new(&mut []);
        let mut memory = Memory::new(&mut heap).unwrap();

        // The first slot is fresh and becomes occupied once filled.
        let first = primitive_set.allocate(&memory);
        assert_eq!(first, 0);
        primitive_set.values[first] = Some(value(42usize));
        assert_eq!(primitive_set.find_free(), None);

        // Reference the first slot from a foreign object on the stack so it
        // survives collection and the next allocation has to grow the table.
        let foreign = memory
            .allocate(
                Number::from_i64(first as _).into(),
                memory.null().set_tag(Type::Foreign as _).into(),
            )
            .unwrap();
        memory.push(foreign.into()).unwrap();

        let second = primitive_set.allocate(&memory);
        assert_eq!(second, 1);
        primitive_set.values[second] = Some(value(42usize));
        assert_eq!(primitive_set.find_free(), None);
    }

    mod garbage_collection {
        use super::*;

        #[test]
        fn collect_none() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let memory = Memory::new(&mut heap).unwrap();
            let mut primitive_set = DynamicPrimitiveSet::new(&mut []);

            primitive_set.collect_garbages(&memory);
        }

        #[tokio::test]
        async fn collect_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
            let mut primitive_set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            // Calling `make-foo` stores its result in the only slot.
            maybe_await!(primitive_set.operate(&mut memory, 1)).unwrap();
            assert_eq!(primitive_set.find_free(), None);

            // Drop the only reference to the foreign object, and let the VM
            // reclaim it before collecting foreign values.
            memory.pop();
            memory.collect_garbages(None).unwrap();
            primitive_set.collect_garbages(&memory);

            assert_eq!(primitive_set.find_free(), Some(0));
        }

        #[tokio::test]
        async fn keep_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
            let mut primitive_set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            // The returned foreign object stays on the stack, so its slot must
            // survive collection.
            maybe_await!(primitive_set.operate(&mut memory, 1)).unwrap();
            assert_eq!(primitive_set.find_free(), None);

            primitive_set.collect_garbages(&memory);
            assert_eq!(primitive_set.find_free(), None);
        }
    }
}