stak_dynamic/
primitive_set.rs

1use crate::{error::DynamicError, scheme_value::SchemeValue};
2use alloc::{boxed::Box, string::String, vec, vec::Vec};
3use any_fn::AnyFn;
4use bitvec::bitvec;
5use core::any::TypeId;
6use stak_vm::{Cons, Error, Memory, Number, PrimitiveSet, Type, Value};
7
/// The capacity of [`ArgumentVec`], i.e. the largest number of arguments
/// that can be collected for a single native function call.
const MAXIMUM_ARGUMENT_COUNT: usize = 16;
9
/// A fixed-capacity vector holding up to `MAXIMUM_ARGUMENT_COUNT` per-argument
/// items of a native function call.
type ArgumentVec<T> = heapless::Vec<T, MAXIMUM_ARGUMENT_COUNT>;
/// A registered type that is marshalled between Scheme and Rust.
///
/// The elements are, in order:
///
/// 1. the `TypeId` of the Rust type,
/// 2. a converter from a Scheme value into a Rust value, which returns `None`
///    when the conversion does not succeed, and
/// 3. a converter from a Rust value into a Scheme value, which may allocate
///    in the given memory.
type SchemeType = (
    TypeId,
    Box<dyn Fn(&Memory, Value) -> Option<any_fn::Value>>,
    Box<dyn Fn(&mut Memory, any_fn::Value) -> Result<Value, DynamicError>>,
);
16
/// A dynamic primitive set equipped with native functions in Rust.
pub struct DynamicPrimitiveSet<'a, 'b> {
    /// Native functions paired with their names. The function at index `n`
    /// is invoked by the primitive `n + 1`; the primitive `0` lists the names.
    functions: &'a mut [(&'a str, AnyFn<'b>)],
    /// Types whose values are marshalled between Scheme and Rust instead of
    /// being stored as foreign values.
    types: Vec<SchemeType>,
    /// Slots for foreign values, i.e. Rust values of unregistered types. A
    /// foreign cons cell in the heap carries the index of its slot. `None`
    /// marks a free slot.
    values: Vec<Option<any_fn::Value>>,
}
23
24impl<'a, 'b> DynamicPrimitiveSet<'a, 'b> {
25    /// Creates a primitive set.
26    pub fn new(functions: &'a mut [(&'a str, AnyFn<'b>)]) -> Self {
27        let mut set = Self {
28            functions,
29            types: vec![],
30            values: vec![],
31        };
32
33        // TODO Support more types including `()` and `Vec<u8>`.
34        set.register_type::<bool>();
35        set.register_type::<i8>();
36        set.register_type::<u8>();
37        set.register_type::<i16>();
38        set.register_type::<u16>();
39        set.register_type::<i32>();
40        set.register_type::<u32>();
41        set.register_type::<i64>();
42        set.register_type::<u64>();
43        set.register_type::<f32>();
44        set.register_type::<f64>();
45        set.register_type::<isize>();
46        set.register_type::<usize>();
47        set.register_type::<String>();
48
49        set
50    }
51
52    /// Registers a type compatible between Scheme and Rust.
53    ///
54    /// Values of such types are automatically marshalled when we pass them from
55    /// Scheme to Rust, and vice versa. Marshalling values can lead to the loss
56    /// of information (e.g. floating-point numbers in Scheme marshalled
57    /// into integers in Rust.)
58    pub fn register_type<T: SchemeValue + 'static>(&mut self) {
59        self.types.push((
60            TypeId::of::<T>(),
61            Box::new(|memory, value| T::from_scheme(memory, value).map(any_fn::value)),
62            Box::new(|memory, value| T::into_scheme(value.downcast()?, memory)),
63        ));
64    }
65
66    fn collect_garbages(&mut self, memory: &Memory) {
67        let mut marks = bitvec![0; self.values.len()];
68
69        for index in 0..(memory.allocation_index() / 2) {
70            let cons = Cons::new((memory.allocation_start() + 2 * index) as _);
71
72            if memory.cdr(cons).tag() != Type::Foreign as _ {
73                continue;
74            }
75
76            marks.set(memory.car(cons).assume_number().to_i64() as _, true);
77        }
78
79        for (index, mark) in marks.into_iter().enumerate() {
80            if !mark {
81                self.values[index] = None;
82            }
83        }
84    }
85
86    // TODO Optimize this with `BitSlice::first_zero()`.
87    fn find_free(&self) -> Option<usize> {
88        self.values.iter().position(Option::is_none)
89    }
90
91    fn allocate(&mut self, memory: &Memory) -> usize {
92        if let Some(index) = self.find_free() {
93            index
94        } else if let Some(index) = {
95            self.collect_garbages(memory);
96            self.find_free()
97        } {
98            index
99        } else {
100            self.values.push(None);
101            self.values.len() - 1
102        }
103    }
104
105    fn convert_from_scheme(
106        &self,
107        memory: &Memory,
108        value: Value,
109        type_id: TypeId,
110    ) -> Option<any_fn::Value> {
111        for (id, from, _) in &self.types {
112            if type_id == *id {
113                return from(memory, value);
114            }
115        }
116
117        None
118    }
119
120    fn convert_into_scheme(
121        &mut self,
122        memory: &mut Memory,
123        value: any_fn::Value,
124    ) -> Result<Value, DynamicError> {
125        for (id, _, into) in &self.types {
126            if value.type_id()? == *id {
127                return into(memory, value);
128            }
129        }
130
131        let index = self.allocate(memory);
132
133        self.values[index] = Some(value);
134
135        Ok(memory
136            .allocate(
137                Number::from_i64(index as _).into(),
138                memory.null().set_tag(Type::Foreign as _).into(),
139            )?
140            .into())
141    }
142}
143
144impl PrimitiveSet for DynamicPrimitiveSet<'_, '_> {
145    type Error = DynamicError;
146
147    fn operate(&mut self, memory: &mut Memory, primitive: usize) -> Result<(), Self::Error> {
148        if primitive == 0 {
149            memory.set_register(memory.null());
150
151            for (name, _) in self.functions.iter().rev() {
152                let list = memory.cons(memory.null().into(), memory.register())?;
153                memory.set_register(list);
154                let string = memory.build_raw_string(name)?;
155                memory.set_car(memory.register(), string.into());
156            }
157
158            memory.push(memory.register().into())?;
159
160            Ok(())
161        } else {
162            let primitive = primitive - 1;
163            let (_, function) = self
164                .functions
165                .get(primitive)
166                .ok_or(Error::IllegalPrimitive)?;
167
168            let mut arguments = (0..function.arity())
169                .map(|_| memory.pop())
170                .collect::<ArgumentVec<_>>();
171            arguments.reverse();
172
173            let cloned_arguments = {
174                arguments
175                    .iter()
176                    .enumerate()
177                    .map(|(index, &value)| {
178                        self.convert_from_scheme(memory, value, function.parameter_types()[index])
179                    })
180                    .collect::<ArgumentVec<_>>()
181            };
182
183            let mut copied_arguments = ArgumentVec::new();
184
185            for &value in &arguments {
186                let value =
187                    if value.is_cons() && memory.cdr_value(value).tag() == Type::Foreign as _ {
188                        Some(
189                            self.values
190                                .get(memory.car_value(value).assume_number().to_i64() as usize)
191                                .ok_or(DynamicError::ValueIndex)?
192                                .as_ref()
193                                .ok_or(DynamicError::ValueIndex)?,
194                        )
195                    } else {
196                        None
197                    };
198
199                copied_arguments
200                    .push(value)
201                    .map_err(|_| Error::ArgumentCount)?;
202            }
203
204            let value = self
205                .functions
206                .get_mut(primitive)
207                .ok_or(Error::IllegalPrimitive)?
208                .1
209                .call(
210                    copied_arguments
211                        .into_iter()
212                        .enumerate()
213                        .map(|(index, value)| {
214                            cloned_arguments[index]
215                                .as_ref()
216                                .map_or_else(|| value.ok_or(DynamicError::ForeignValueExpected), Ok)
217                        })
218                        .collect::<Result<ArgumentVec<_>, DynamicError>>()?
219                        .as_slice(),
220                )?;
221
222            let value = self.convert_into_scheme(memory, value)?;
223            memory.push(value)?;
224
225            Ok(())
226        }
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use any_fn::{Ref, r#fn, value};

    const HEAP_SIZE: usize = 1 << 8;

    /// A type that is not registered for marshalling and is therefore passed
    /// around as a foreign value.
    struct Foo {
        bar: usize,
    }

    impl Foo {
        const fn new(bar: usize) -> Self {
            Self { bar }
        }

        const fn bar(&self) -> usize {
            self.bar
        }

        fn baz(&mut self, value: usize) {
            self.bar += value;
        }
    }

    /// A primitive set can be built from functions of various shapes.
    #[test]
    fn create() {
        let mut functions = [
            ("make-foo", r#fn(Foo::new)),
            ("foo-bar", r#fn::<(Ref<_>,), _>(Foo::bar)),
            ("foo-baz", r#fn(Foo::baz)),
        ];

        DynamicPrimitiveSet::new(&mut functions);
    }

    /// A second allocation gets a fresh slot while the first one is still
    /// referenced from the heap.
    #[test]
    fn allocate_two() {
        let mut heap = [Default::default(); HEAP_SIZE];
        let mut primitive_set = DynamicPrimitiveSet::new(&mut []);
        let mut memory = Memory::new(&mut heap).unwrap();

        let first = primitive_set.allocate(&memory);
        assert_eq!(first, 0);
        primitive_set.values[first] = Some(value(42usize));
        assert_eq!(primitive_set.find_free(), None);

        // Keep the first slot alive through a foreign cons cell on the stack.
        let foreign = memory
            .allocate(
                Number::from_i64(first as _).into(),
                memory.null().set_tag(Type::Foreign as _).into(),
            )
            .unwrap();
        memory.push(foreign.into()).unwrap();

        let second = primitive_set.allocate(&memory);
        assert_eq!(second, 1);
        primitive_set.values[second] = Some(value(42usize));
        assert_eq!(primitive_set.find_free(), None);
    }

    mod garbage_collection {
        use super::*;

        /// Collecting with no foreign values at all does not fail.
        #[test]
        fn collect_none() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut primitive_set = DynamicPrimitiveSet::new(&mut []);

            primitive_set.collect_garbages(&Memory::new(&mut heap).unwrap());
        }

        /// A foreign value is freed once the heap no longer refers to it.
        #[test]
        fn collect_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo::new(42)))];
            let mut primitive_set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            primitive_set.operate(&mut memory, 1).unwrap();
            assert_eq!(primitive_set.find_free(), None);

            // Drop the only reference, which is the return value of the
            // foreign primitive, and let the VM reclaim its cons cell.
            memory.pop();
            memory.collect_garbages(None).unwrap();

            primitive_set.collect_garbages(&memory);
            assert_eq!(primitive_set.find_free(), Some(0));
        }

        /// A foreign value survives while the heap still refers to it.
        #[test]
        fn keep_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo::new(42)))];
            let mut primitive_set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            primitive_set.operate(&mut memory, 1).unwrap();
            assert_eq!(primitive_set.find_free(), None);

            primitive_set.collect_garbages(&memory);
            assert_eq!(primitive_set.find_free(), None);
        }
    }
}