stak_dynamic/
primitive_set.rs

1use crate::{error::DynamicError, scheme_value::SchemeValue};
2use alloc::{boxed::Box, string::String, vec, vec::Vec};
3use any_fn::AnyFn;
4use bitvec::bitvec;
5use core::any::TypeId;
6use stak_vm::{Cons, Error, Memory, Number, PrimitiveSet, Type, Value};
7
8const MAXIMUM_ARGUMENT_COUNT: usize = 16;
9
10type ArgumentVec<T> = heapless::Vec<T, MAXIMUM_ARGUMENT_COUNT>;
11type SchemeType = (
12    TypeId,
13    Box<dyn Fn(&Memory, Value) -> Option<any_fn::Value>>,
14    Box<dyn Fn(&mut Memory, any_fn::Value) -> Result<Value, DynamicError>>,
15);
16
17/// A dynamic primitive set equipped with native functions in Rust.
18pub struct DynamicPrimitiveSet<'a, 'b> {
19    functions: &'a mut [(&'a str, AnyFn<'b>)],
20    types: Vec<SchemeType>,
21    values: Vec<Option<any_fn::Value>>,
22}
23
24impl<'a, 'b> DynamicPrimitiveSet<'a, 'b> {
25    /// Creates a primitive set.
26    pub fn new(functions: &'a mut [(&'a str, AnyFn<'b>)]) -> Self {
27        let mut set = Self {
28            functions,
29            types: vec![],
30            values: vec![],
31        };
32
33        // TODO Support more types including `()` and `Vec<u8>`.
34        set.register_type::<bool>();
35        set.register_type::<i8>();
36        set.register_type::<u8>();
37        set.register_type::<i16>();
38        set.register_type::<u16>();
39        set.register_type::<i32>();
40        set.register_type::<u32>();
41        set.register_type::<i64>();
42        set.register_type::<u64>();
43        set.register_type::<f32>();
44        set.register_type::<f64>();
45        set.register_type::<isize>();
46        set.register_type::<usize>();
47        set.register_type::<String>();
48
49        set
50    }
51
52    /// Registers a type compatible between Scheme and Rust.
53    ///
54    /// Values of such types are automatically marshalled when we pass them from
55    /// Scheme to Rust, and vice versa. Marshalling values can lead to the loss
56    /// of information (e.g. floating-point numbers in Scheme marshalled
57    /// into integers in Rust.)
58    pub fn register_type<T: SchemeValue + 'static>(&mut self) {
59        self.types.push((
60            TypeId::of::<T>(),
61            Box::new(|memory, value| T::from_scheme(memory, value).map(any_fn::value)),
62            Box::new(|memory, value| T::into_scheme(value.downcast()?, memory)),
63        ));
64    }
65
66    fn collect_garbages(&mut self, memory: &Memory) {
67        let mut marks = bitvec![0; self.values.len()];
68
69        for index in 0..(memory.allocation_index() / 2) {
70            let cons = Cons::new((memory.allocation_start() + 2 * index) as _);
71
72            if memory.cdr(cons).tag() != Type::Foreign as _ {
73                continue;
74            }
75
76            let index = memory.car(cons).assume_number().to_i64() as _;
77
78            // Be conservative as foreign type tags can be used for something else.
79            if index >= self.values.len() {
80                continue;
81            }
82
83            marks.set(index, true);
84        }
85
86        for (index, mark) in marks.into_iter().enumerate() {
87            if !mark {
88                self.values[index] = None;
89            }
90        }
91    }
92
93    // TODO Optimize this with `BitSlice::first_zero()`.
94    fn find_free(&self) -> Option<usize> {
95        self.values.iter().position(Option::is_none)
96    }
97
98    fn allocate(&mut self, memory: &Memory) -> usize {
99        if let Some(index) = self.find_free() {
100            index
101        } else if let Some(index) = {
102            self.collect_garbages(memory);
103            self.find_free()
104        } {
105            index
106        } else {
107            self.values.push(None);
108            self.values.len() - 1
109        }
110    }
111
112    fn convert_from_scheme(
113        &self,
114        memory: &Memory,
115        value: Value,
116        type_id: TypeId,
117    ) -> Option<any_fn::Value> {
118        for (id, from, _) in &self.types {
119            if type_id == *id {
120                return from(memory, value);
121            }
122        }
123
124        None
125    }
126
127    fn convert_into_scheme(
128        &mut self,
129        memory: &mut Memory,
130        value: any_fn::Value,
131    ) -> Result<Value, DynamicError> {
132        for (id, _, into) in &self.types {
133            if value.type_id()? == *id {
134                return into(memory, value);
135            }
136        }
137
138        let index = self.allocate(memory);
139
140        self.values[index] = Some(value);
141
142        Ok(memory
143            .allocate(
144                Number::from_i64(index as _).into(),
145                memory.null().set_tag(Type::Foreign as _).into(),
146            )?
147            .into())
148    }
149}
150
151impl PrimitiveSet for DynamicPrimitiveSet<'_, '_> {
152    type Error = DynamicError;
153
154    fn operate(&mut self, memory: &mut Memory, primitive: usize) -> Result<(), Self::Error> {
155        if primitive == 0 {
156            memory.set_register(memory.null());
157
158            for (name, _) in self.functions.iter().rev() {
159                let list = memory.cons(memory.null().into(), memory.register())?;
160                memory.set_register(list);
161                let string = memory.build_string(name)?;
162                memory.set_car(memory.register(), string.into());
163            }
164
165            memory.push(memory.register().into())?;
166
167            Ok(())
168        } else {
169            let primitive = primitive - 1;
170            let (_, function) = self
171                .functions
172                .get(primitive)
173                .ok_or(Error::IllegalPrimitive)?;
174
175            let mut arguments = (0..function.arity())
176                .map(|_| memory.pop())
177                .collect::<ArgumentVec<_>>();
178            arguments.reverse();
179
180            let cloned_arguments = {
181                arguments
182                    .iter()
183                    .enumerate()
184                    .map(|(index, &value)| {
185                        self.convert_from_scheme(memory, value, function.parameter_types()[index])
186                    })
187                    .collect::<ArgumentVec<_>>()
188            };
189
190            let mut copied_arguments = ArgumentVec::new();
191
192            for &value in &arguments {
193                let value =
194                    if value.is_cons() && memory.cdr_value(value).tag() == Type::Foreign as _ {
195                        Some(
196                            self.values
197                                .get(memory.car_value(value).assume_number().to_i64() as usize)
198                                .ok_or(DynamicError::ValueIndex)?
199                                .as_ref()
200                                .ok_or(DynamicError::ValueIndex)?,
201                        )
202                    } else {
203                        None
204                    };
205
206                copied_arguments
207                    .push(value)
208                    .map_err(|_| Error::ArgumentCount)?;
209            }
210
211            let value = self
212                .functions
213                .get_mut(primitive)
214                .ok_or(Error::IllegalPrimitive)?
215                .1
216                .call(
217                    copied_arguments
218                        .into_iter()
219                        .enumerate()
220                        .map(|(index, value)| {
221                            cloned_arguments[index]
222                                .as_ref()
223                                .map_or_else(|| value.ok_or(DynamicError::ForeignValueExpected), Ok)
224                        })
225                        .collect::<Result<ArgumentVec<_>, DynamicError>>()?
226                        .as_slice(),
227                )?;
228
229            let value = self.convert_into_scheme(memory, value)?;
230            memory.push(value)?;
231
232            Ok(())
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use any_fn::{Ref, r#fn, value};

    const HEAP_SIZE: usize = 1 << 8;

    /// Sample native type that is not registered for marshalling.
    struct Foo {
        bar: usize,
    }

    impl Foo {
        const fn new(bar: usize) -> Self {
            Self { bar }
        }

        const fn bar(&self) -> usize {
            self.bar
        }

        fn baz(&mut self, value: usize) {
            self.bar += value;
        }
    }

    /// Rewrites the slot index of every foreign-tagged cell in the whole heap
    /// to `1 << 16`, so that no cell refers to a slot used by these tests.
    fn invalidate_foreign_values(memory: &mut Memory) {
        let cell_count = memory.size() / 2;

        for cons in (0..cell_count).map(|cell| Cons::new((2 * cell) as _)) {
            if memory.cdr(cons).tag() == Type::Foreign as _ {
                memory.set_car(cons, Number::from_i64(1 << 16).into());
            }
        }
    }

    #[test]
    fn create() {
        let mut functions = [
            ("make-foo", r#fn(Foo::new)),
            ("foo-bar", r#fn::<(Ref<_>,), _>(Foo::bar)),
            ("foo-baz", r#fn(Foo::baz)),
        ];

        DynamicPrimitiveSet::new(&mut functions);
    }

    #[test]
    fn allocate_two() {
        let mut heap = [Default::default(); HEAP_SIZE];
        let mut set = DynamicPrimitiveSet::new(&mut []);
        let mut memory = Memory::new(&mut heap).unwrap();

        let first = set.allocate(&memory);
        set.values[first] = Some(value(42usize));
        assert_eq!(first, 0);
        assert_eq!(set.find_free(), None);

        // A foreign cell pointing at the first slot keeps it alive through collection.
        let foreign = memory
            .allocate(
                Number::from_i64(first as _).into(),
                memory.null().set_tag(Type::Foreign as _).into(),
            )
            .unwrap();
        memory.push(foreign.into()).unwrap();

        let second = set.allocate(&memory);
        set.values[second] = Some(value(42usize));
        assert_eq!(second, 1);
        assert_eq!(set.find_free(), None);
    }

    mod garbage_collection {
        use super::*;

        #[test]
        fn collect_none() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut set = DynamicPrimitiveSet::new(&mut []);
            let memory = Memory::new(&mut heap).unwrap();

            set.collect_garbages(&memory);
        }

        #[test]
        fn collect_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
            let mut set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            // Primitive 1 calls `make-foo`, whose result occupies slot 0.
            set.operate(&mut memory, 1).unwrap();
            assert_eq!(set.find_free(), None);

            // Once no cell refers to slot 0, collection frees it.
            invalidate_foreign_values(&mut memory);
            set.collect_garbages(&memory);
            assert_eq!(set.find_free(), Some(0));
        }

        #[test]
        fn keep_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
            let mut set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            set.operate(&mut memory, 1).unwrap();
            assert_eq!(set.find_free(), None);

            // The pushed foreign cell still refers to slot 0, so it survives.
            set.collect_garbages(&memory);
            assert_eq!(set.find_free(), None);
        }
    }
}