1use crate::{error::DynamicError, scheme_value::SchemeValue};
2use alloc::{boxed::Box, string::String, vec, vec::Vec};
3use any_fn::AnyFn;
4use bitvec::bitvec;
5use core::any::TypeId;
6use stak_vm::{Cons, Error, Memory, Number, PrimitiveSet, Type, Value};
7use winter_maybe_async::maybe_async;
8
9const MAXIMUM_ARGUMENT_COUNT: usize = 16;
10
11type ArgumentVec<T> = heapless::Vec<T, MAXIMUM_ARGUMENT_COUNT>;
12type SchemeType = (
13 TypeId,
14 Box<dyn Fn(&Memory, Value) -> Option<any_fn::Value>>,
15 Box<dyn Fn(&mut Memory, any_fn::Value) -> Result<Value, DynamicError>>,
16);
17
18pub struct DynamicPrimitiveSet<'a, 'b> {
20 functions: &'a mut [(&'a str, AnyFn<'b>)],
21 types: Vec<SchemeType>,
22 values: Vec<Option<any_fn::Value>>,
23}
24
25impl<'a, 'b> DynamicPrimitiveSet<'a, 'b> {
26 pub fn new(functions: &'a mut [(&'a str, AnyFn<'b>)]) -> Self {
28 let mut set = Self {
29 functions,
30 types: vec![],
31 values: vec![],
32 };
33
34 set.register_type::<bool>();
36 set.register_type::<i8>();
37 set.register_type::<u8>();
38 set.register_type::<i16>();
39 set.register_type::<u16>();
40 set.register_type::<i32>();
41 set.register_type::<u32>();
42 set.register_type::<i64>();
43 set.register_type::<u64>();
44 set.register_type::<f32>();
45 set.register_type::<f64>();
46 set.register_type::<isize>();
47 set.register_type::<usize>();
48 set.register_type::<String>();
49
50 set
51 }
52
53 pub fn register_type<T: SchemeValue + 'static>(&mut self) {
60 self.types.push((
61 TypeId::of::<T>(),
62 Box::new(|memory, value| T::from_scheme(memory, value).map(any_fn::value)),
63 Box::new(|memory, value| T::into_scheme(value.downcast()?, memory)),
64 ));
65 }
66
67 fn collect_garbages(&mut self, memory: &Memory) {
68 let mut marks = bitvec![0; self.values.len()];
69
70 for index in 0..(memory.allocation_index() / 2) {
71 let cons = Cons::new((memory.allocation_start() + 2 * index) as _);
72
73 if memory.cdr(cons).tag() != Type::Foreign as _ {
74 continue;
75 }
76
77 marks.set(memory.car(cons).assume_number().to_i64() as _, true);
78 }
79
80 for (index, mark) in marks.into_iter().enumerate() {
81 if !mark {
82 self.values[index] = None;
83 }
84 }
85 }
86
87 fn find_free(&self) -> Option<usize> {
89 self.values.iter().position(Option::is_none)
90 }
91
92 fn allocate(&mut self, memory: &Memory) -> usize {
93 if let Some(index) = self.find_free() {
94 index
95 } else if let Some(index) = {
96 self.collect_garbages(memory);
97 self.find_free()
98 } {
99 index
100 } else {
101 self.values.push(None);
102 self.values.len() - 1
103 }
104 }
105
106 fn convert_from_scheme(
107 &self,
108 memory: &Memory,
109 value: Value,
110 type_id: TypeId,
111 ) -> Option<any_fn::Value> {
112 for (id, from, _) in &self.types {
113 if type_id == *id {
114 return from(memory, value);
115 }
116 }
117
118 None
119 }
120
121 fn convert_into_scheme(
122 &mut self,
123 memory: &mut Memory,
124 value: any_fn::Value,
125 ) -> Result<Value, DynamicError> {
126 for (id, _, into) in &self.types {
127 if value.type_id()? == *id {
128 return into(memory, value);
129 }
130 }
131
132 let index = self.allocate(memory);
133
134 self.values[index] = Some(value);
135
136 Ok(memory
137 .allocate(
138 Number::from_i64(index as _).into(),
139 memory.null().set_tag(Type::Foreign as _).into(),
140 )?
141 .into())
142 }
143}
144
145impl PrimitiveSet for DynamicPrimitiveSet<'_, '_> {
146 type Error = DynamicError;
147
148 #[maybe_async]
149 fn operate(&mut self, memory: &mut Memory<'_>, primitive: usize) -> Result<(), Self::Error> {
150 if primitive == 0 {
151 memory.set_register(memory.null());
152
153 for (name, _) in self.functions.iter().rev() {
154 let list = memory.cons(memory.null().into(), memory.register())?;
155 memory.set_register(list);
156 let string = memory.build_raw_string(name)?;
157 memory.set_car(memory.register(), string.into());
158 }
159
160 memory.push(memory.register().into())?;
161
162 Ok(())
163 } else {
164 let primitive = primitive - 1;
165 let (_, function) = self
166 .functions
167 .get(primitive)
168 .ok_or(Error::IllegalPrimitive)?;
169
170 let mut arguments = (0..function.arity())
171 .map(|_| memory.pop())
172 .collect::<ArgumentVec<_>>();
173 arguments.reverse();
174
175 let cloned_arguments = {
176 arguments
177 .iter()
178 .enumerate()
179 .map(|(index, &value)| {
180 self.convert_from_scheme(memory, value, function.parameter_types()[index])
181 })
182 .collect::<ArgumentVec<_>>()
183 };
184
185 let mut copied_arguments = ArgumentVec::new();
186
187 for &value in &arguments {
188 let value =
189 if value.is_cons() && memory.cdr_value(value).tag() == Type::Foreign as _ {
190 Some(
191 self.values
192 .get(memory.car_value(value).assume_number().to_i64() as usize)
193 .ok_or(DynamicError::ValueIndex)?
194 .as_ref()
195 .ok_or(DynamicError::ValueIndex)?,
196 )
197 } else {
198 None
199 };
200
201 copied_arguments
202 .push(value)
203 .map_err(|_| Error::ArgumentCount)?;
204 }
205
206 let value = self
207 .functions
208 .get_mut(primitive)
209 .ok_or(Error::IllegalPrimitive)?
210 .1
211 .call(
212 copied_arguments
213 .into_iter()
214 .enumerate()
215 .map(|(index, value)| {
216 cloned_arguments[index]
217 .as_ref()
218 .map_or_else(|| value.ok_or(DynamicError::ForeignValueExpected), Ok)
219 })
220 .collect::<Result<ArgumentVec<_>, DynamicError>>()?
221 .as_slice(),
222 )?;
223
224 let value = self.convert_into_scheme(memory, value)?;
225 memory.push(value)?;
226
227 Ok(())
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use any_fn::{Ref, r#fn, value};
    use winter_maybe_async::maybe_await;

    const HEAP_SIZE: usize = 1 << 8;

    /// A type with no registered conversion, so its instances stay on the Rust side.
    struct Foo {
        bar: usize,
    }

    impl Foo {
        const fn new(bar: usize) -> Self {
            Self { bar }
        }

        const fn bar(&self) -> usize {
            self.bar
        }

        fn baz(&mut self, value: usize) {
            self.bar += value;
        }
    }

    #[test]
    fn create() {
        let mut functions = [
            ("make-foo", r#fn(Foo::new)),
            ("foo-bar", r#fn::<(Ref<_>,), _>(Foo::bar)),
            ("foo-baz", r#fn(Foo::baz)),
        ];

        DynamicPrimitiveSet::new(&mut functions);
    }

    #[test]
    fn allocate_two() {
        let mut heap = [Default::default(); HEAP_SIZE];
        let mut set = DynamicPrimitiveSet::new(&mut []);
        let mut memory = Memory::new(&mut heap).unwrap();

        let first = set.allocate(&memory);
        set.values[first] = Some(value(42usize));
        assert_eq!(first, 0);
        assert_eq!(set.find_free(), None);

        // Reference the first slot from a foreign-tagged cell so collection keeps it.
        let reference = memory
            .allocate(
                Number::from_i64(first as _).into(),
                memory.null().set_tag(Type::Foreign as _).into(),
            )
            .unwrap();
        memory.push(reference.into()).unwrap();

        let second = set.allocate(&memory);
        set.values[second] = Some(value(42usize));
        assert_eq!(second, 1);
        assert_eq!(set.find_free(), None);
    }

    mod garbage_collection {
        use super::*;

        #[test]
        fn collect_none() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut set = DynamicPrimitiveSet::new(&mut []);
            let memory = Memory::new(&mut heap).unwrap();

            // Collecting with no stored values must simply succeed.
            set.collect_garbages(&memory);
        }

        #[tokio::test]
        async fn collect_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
            let mut set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            // Primitive 1 calls `make-foo`, storing a `Foo` in slot 0.
            maybe_await!(set.operate(&mut memory, 1)).unwrap();
            assert_eq!(set.find_free(), None);

            // Drop the only reference and run the VM's collector before ours.
            memory.pop();
            memory.collect_garbages(None).unwrap();
            set.collect_garbages(&memory);

            assert_eq!(set.find_free(), Some(0));
        }

        #[tokio::test]
        async fn keep_one() {
            let mut heap = [Default::default(); HEAP_SIZE];
            let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
            let mut set = DynamicPrimitiveSet::new(&mut functions);
            let mut memory = Memory::new(&mut heap).unwrap();

            maybe_await!(set.operate(&mut memory, 1)).unwrap();
            assert_eq!(set.find_free(), None);

            // The result is still on the stack, so its slot must survive collection.
            set.collect_garbages(&memory);

            assert_eq!(set.find_free(), None);
        }
    }
}