// stak_dynamic/primitive_set.rs

1use crate::{error::DynamicError, scheme_value::SchemeValue};
2use alloc::{boxed::Box, string::String, vec, vec::Vec};
3use any_fn::AnyFn;
4use bitvec::bitvec;
5use core::any::TypeId;
6use stak_vm::{Cons, Error, Memory, Number, PrimitiveSet, Type, Value};
7
8const MAXIMUM_ARGUMENT_COUNT: usize = 16;
9
10type ArgumentVec<T> = heapless::Vec<T, MAXIMUM_ARGUMENT_COUNT>;
11type SchemeType = (
12    TypeId,
13    Box<dyn Fn(&Memory, Value) -> Option<any_fn::Value>>,
14    Box<dyn Fn(&mut Memory, any_fn::Value) -> Result<Value, DynamicError>>,
15);
16
17/// A dynamic primitive set equipped with native functions in Rust.
18pub struct DynamicPrimitiveSet<'a, 'b> {
19    functions: &'a mut [AnyFn<'b>],
20    types: Vec<SchemeType>,
21    values: Vec<Option<any_fn::Value>>,
22}
23
24impl<'a, 'b> DynamicPrimitiveSet<'a, 'b> {
25    /// Creates a primitive set.
26    pub fn new(functions: &'a mut [AnyFn<'b>]) -> Self {
27        let mut set = Self {
28            functions,
29            types: vec![],
30            values: vec![],
31        };
32
33        set.register_type::<bool>();
34        set.register_type::<i8>();
35        set.register_type::<u8>();
36        set.register_type::<i16>();
37        set.register_type::<u16>();
38        set.register_type::<i32>();
39        set.register_type::<u32>();
40        set.register_type::<i64>();
41        set.register_type::<u64>();
42        set.register_type::<f32>();
43        set.register_type::<f64>();
44        set.register_type::<isize>();
45        set.register_type::<usize>();
46        set.register_type::<String>();
47
48        set
49    }
50
51    /// Registers a type compatible between Scheme and Rust.
52    ///
53    /// Values of such types are automatically marshalled when we pass them from
54    /// Scheme to Rust, and vice versa. Marshalling values can lead to the loss
55    /// of information (e.g. floating-point numbers in Scheme marshalled
56    /// into integers in Rust.)
57    pub fn register_type<T: SchemeValue + 'static>(&mut self) {
58        self.types.push((
59            TypeId::of::<T>(),
60            Box::new(|memory, value| T::from_scheme(memory, value).map(any_fn::value)),
61            Box::new(|memory, value| T::into_scheme(value.downcast()?, memory)),
62        ));
63    }
64
65    fn collect_garbages(&mut self, memory: &Memory) {
66        let mut marks = bitvec![0; self.values.len()];
67
68        for index in 0..(memory.allocation_index() / 2) {
69            let cons = Cons::new((memory.allocation_start() + 2 * index) as _);
70
71            if memory.cdr(cons).tag() != Type::Foreign as _ {
72                continue;
73            }
74
75            let index = memory.car(cons).assume_number().to_i64() as _;
76
77            // Be conservative as foreign type tags can be used for something else.
78            if index >= self.values.len() {
79                continue;
80            }
81
82            marks.set(index, true);
83        }
84
85        for (index, mark) in marks.into_iter().enumerate() {
86            if !mark {
87                self.values[index] = None;
88            }
89        }
90    }
91
92    // TODO Optimize this with `BitSlice::first_zero()`.
93    fn find_free(&self) -> Option<usize> {
94        self.values.iter().position(Option::is_none)
95    }
96
97    fn allocate(&mut self, memory: &Memory) -> usize {
98        if let Some(index) = self.find_free() {
99            index
100        } else if let Some(index) = {
101            self.collect_garbages(memory);
102            self.find_free()
103        } {
104            index
105        } else {
106            self.values.push(None);
107            self.values.len() - 1
108        }
109    }
110
111    fn convert_from_scheme(
112        &self,
113        memory: &Memory,
114        value: Value,
115        type_id: TypeId,
116    ) -> Option<any_fn::Value> {
117        for (id, from, _) in &self.types {
118            if type_id == *id {
119                return from(memory, value);
120            }
121        }
122
123        None
124    }
125
126    fn convert_into_scheme(
127        &mut self,
128        memory: &mut Memory,
129        value: any_fn::Value,
130    ) -> Result<Value, DynamicError> {
131        for (id, _, into) in &self.types {
132            if value.type_id()? == *id {
133                return into(memory, value);
134            }
135        }
136
137        let index = self.allocate(memory);
138
139        self.values[index] = Some(value);
140
141        Ok(memory
142            .allocate(
143                Number::from_i64(index as _).into(),
144                memory.null().set_tag(Type::Foreign as _).into(),
145            )?
146            .into())
147    }
148}
149
150impl PrimitiveSet for DynamicPrimitiveSet<'_, '_> {
151    type Error = DynamicError;
152
153    fn operate(&mut self, memory: &mut Memory, primitive: usize) -> Result<(), Self::Error> {
154        let function = self
155            .functions
156            .get(primitive)
157            .ok_or(Error::IllegalPrimitive)?;
158
159        let mut arguments = (0..function.arity())
160            .map(|_| memory.pop())
161            .collect::<ArgumentVec<_>>();
162        arguments.reverse();
163
164        let cloned_arguments = {
165            arguments
166                .iter()
167                .enumerate()
168                .map(|(index, &value)| {
169                    self.convert_from_scheme(memory, value, function.parameter_types()[index])
170                })
171                .collect::<ArgumentVec<_>>()
172        };
173
174        let mut copied_arguments = ArgumentVec::new();
175
176        for &value in &arguments {
177            let value = if value.is_cons() && memory.cdr_value(value).tag() == Type::Foreign as _ {
178                Some(
179                    self.values
180                        .get(memory.car_value(value).assume_number().to_i64() as usize)
181                        .ok_or(DynamicError::ValueIndex)?
182                        .as_ref()
183                        .ok_or(DynamicError::ValueIndex)?,
184                )
185            } else {
186                None
187            };
188
189            copied_arguments
190                .push(value)
191                .map_err(|_| Error::ArgumentCount)?;
192        }
193
194        let value = self
195            .functions
196            .get_mut(primitive)
197            .ok_or(Error::IllegalPrimitive)?
198            .call(
199                copied_arguments
200                    .into_iter()
201                    .enumerate()
202                    .map(|(index, value)| {
203                        cloned_arguments[index]
204                            .as_ref()
205                            .map_or_else(|| value.ok_or(DynamicError::ForeignValueExpected), Ok)
206                    })
207                    .collect::<Result<ArgumentVec<_>, DynamicError>>()?
208                    .as_slice(),
209            )?;
210
211        let value = self.convert_into_scheme(memory, value)?;
212        memory.push(value)?;
213
214        Ok(())
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use any_fn::{Ref, r#fn, value};

    const HEAP_SIZE: usize = 1 << 8;

    /// A Rust type with no registered Scheme representation, so its
    /// instances are stored as foreign values.
    struct Foo {
        bar: usize,
    }

    impl Foo {
        const fn new(bar: usize) -> Self {
            Self { bar }
        }

        const fn bar(&self) -> usize {
            self.bar
        }

        fn baz(&mut self, amount: usize) {
            self.bar += amount;
        }
    }

    /// Rewrites the car of every foreign-tagged cons cell anywhere in the heap
    /// to an index (`1 << 16`) far beyond any slot these tests create.
    fn invalidate_foreign_values(memory: &mut Memory) {
        for cell in 0..(memory.size() / 2) {
            let cons = Cons::new((2 * cell) as _);

            if memory.cdr(cons).tag() == Type::Foreign as _ {
                memory.set_car(cons, Number::from_i64(1 << 16).into());
            }
        }
    }

    #[test]
    fn create() {
        let mut natives = [
            r#fn(Foo::new),
            r#fn::<(Ref<_>,), _>(Foo::bar),
            r#fn(Foo::baz),
        ];

        DynamicPrimitiveSet::new(&mut natives);
    }

    #[test]
    fn allocate_two() {
        let mut heap = [Default::default(); HEAP_SIZE];
        let mut set = DynamicPrimitiveSet::new(&mut []);
        let mut memory = Memory::new(&mut heap).unwrap();

        let first = set.allocate(&memory);
        set.values[first] = Some(value(42usize));
        assert_eq!(first, 0);
        assert_eq!(set.find_free(), None);

        // Reference the first slot from a foreign cell so collection keeps it.
        let foreign = memory
            .allocate(
                Number::from_i64(first as _).into(),
                memory.null().set_tag(Type::Foreign as _).into(),
            )
            .unwrap();
        memory.push(foreign.into()).unwrap();

        let second = set.allocate(&memory);
        set.values[second] = Some(value(42usize));
        assert_eq!(second, 1);
        assert_eq!(set.find_free(), None);
    }

    mod garbage_collection {
        use super::*;

        #[test]
        fn collect_none() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut set = DynamicPrimitiveSet::new(&mut []);

            set.collect_garbages(&Memory::new(&mut heap).unwrap());
        }

        #[test]
        fn collect_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut natives = [r#fn(|| Foo { bar: 42 })];
            let mut set = DynamicPrimitiveSet::new(&mut natives);
            let mut memory = Memory::new(&mut heap).unwrap();

            // Calling the constructor stores a `Foo` in slot 0.
            set.operate(&mut memory, 0).unwrap();
            assert_eq!(set.find_free(), None);

            // Once no cell points at slot 0, collection frees it.
            invalidate_foreign_values(&mut memory);
            set.collect_garbages(&memory);

            assert_eq!(set.find_free(), Some(0));
        }

        #[test]
        fn keep_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut natives = [r#fn(|| Foo { bar: 42 })];
            let mut set = DynamicPrimitiveSet::new(&mut natives);
            let mut memory = Memory::new(&mut heap).unwrap();

            set.operate(&mut memory, 0).unwrap();
            assert_eq!(set.find_free(), None);

            // The foreign cell returned by the call still references slot 0.
            set.collect_garbages(&memory);

            assert_eq!(set.find_free(), None);
        }
    }
}