Skip to main content

sim_lib_numbers_tensor/implementation/
citizen.rs

1//! The tensor value class as a runtime citizen: its class registration and the
2//! read-constructor that reconstructs tensor values from encoded form.
3
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU32, Ordering};
6
7use sim_citizen::{CitizenField, arity_error, decode_version};
8use sim_kernel::{
9    Args, Callable, Class, ClassId, ClassRef, Cx, DefaultFactory, Error, Expr, Factory, Linker,
10    Object, ReadConstructor, ReadConstructorRef, Result, ShapeRef, Symbol, TableRef, Value,
11    force_list_to_vec,
12};
13use sim_lib_numbers_core::domains;
14
15use super::{domain::number_domain, value::build_tensor_value};
16
17/// The symbol naming the tensor value class (`numbers/Tensor`) under which
18/// tensor values register and reconstruct.
19pub fn tensor_value_class_symbol() -> Symbol {
20    domains::tensor_value_class()
21}
22
23fn value_shape_symbol() -> Symbol {
24    sim_lib_numbers_core::value_shape_symbol(&number_domain())
25}
26
27struct TensorValueClass {
28    id: AtomicU32,
29}
30
31impl TensorValueClass {
32    fn new() -> Self {
33        Self {
34            id: AtomicU32::new(0),
35        }
36    }
37
38    fn set_id(&self, id: ClassId) {
39        self.id.store(id.0, Ordering::Relaxed);
40    }
41}
42
43impl Object for TensorValueClass {
44    fn display(&self, _cx: &mut Cx) -> Result<String> {
45        Ok(format!("#<class {}>", tensor_value_class_symbol()))
46    }
47
48    fn as_any(&self) -> &dyn std::any::Any {
49        self
50    }
51}
52
53impl sim_kernel::ObjectCompat for TensorValueClass {
54    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
55        if let Some(value) = cx
56            .registry()
57            .class_by_symbol(&Symbol::qualified("core", "Class"))
58        {
59            return Ok(value.clone());
60        }
61        DefaultFactory.class_stub(
62            sim_kernel::CORE_CLASS_CLASS_ID,
63            Symbol::qualified("core", "Class"),
64        )
65    }
66
67    fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
68        Ok(Expr::Symbol(tensor_value_class_symbol()))
69    }
70
71    fn as_callable(&self) -> Option<&dyn Callable> {
72        Some(self)
73    }
74
75    fn as_class(&self) -> Option<&dyn Class> {
76        Some(self)
77    }
78
79    fn as_read_constructor(&self) -> Option<&dyn ReadConstructor> {
80        Some(self)
81    }
82}
83
84impl Callable for TensorValueClass {
85    fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
86        let values = args.into_vec();
87        let [version, shape, data, domain] = values.as_slice() else {
88            return Err(arity_error(tensor_value_class_symbol(), 4, values.len()));
89        };
90        decode_version(cx, version.clone(), 1, tensor_value_class_symbol())?;
91        let shape = Vec::<usize>::decode_field_value(cx, shape.clone(), "shape")?;
92        let data = decode_data(cx, data)?;
93        let domain = decode_domain(cx, domain)?;
94        if domain == number_domain() {
95            return Err(Error::Eval(
96                "numbers/Tensor domain field must name a scalar number domain".to_owned(),
97            ));
98        }
99        build_tensor_value(cx, shape, Some(domain), data)
100    }
101}
102
103impl Class for TensorValueClass {
104    fn id(&self) -> ClassId {
105        ClassId(self.id.load(Ordering::Relaxed))
106    }
107
108    fn symbol(&self) -> Symbol {
109        tensor_value_class_symbol()
110    }
111
112    fn constructor_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
113        cx.factory().nil()
114    }
115
116    fn instance_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
117        Ok(cx
118            .registry()
119            .shape_by_symbol(&value_shape_symbol())
120            .cloned()
121            .unwrap_or(cx.factory().symbol(value_shape_symbol())?))
122    }
123
124    fn read_constructor(&self, cx: &mut Cx) -> Result<Option<ReadConstructorRef>> {
125        Ok(cx
126            .registry()
127            .class_by_symbol(&tensor_value_class_symbol())
128            .cloned())
129    }
130
131    fn members(&self, cx: &mut Cx) -> Result<TableRef> {
132        cx.factory().table(vec![
133            (
134                Symbol::new("version"),
135                cx.factory()
136                    .number_literal(Symbol::qualified("citizen", "int"), "1".to_owned())?,
137            ),
138            (
139                Symbol::new("fields"),
140                cx.factory().list(vec![
141                    cx.factory().symbol(Symbol::new("shape"))?,
142                    cx.factory().symbol(Symbol::new("data"))?,
143                    cx.factory().symbol(Symbol::new("domain"))?,
144                ])?,
145            ),
146        ])
147    }
148}
149
150impl ReadConstructor for TensorValueClass {
151    fn symbol(&self) -> Symbol {
152        tensor_value_class_symbol()
153    }
154
155    fn args_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
156        cx.factory().nil()
157    }
158
159    fn construct_read(&self, cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
160        if args.len() != 4 {
161            return Err(arity_error(tensor_value_class_symbol(), 4, args.len()));
162        }
163        self.call(cx, Args::new(args))
164    }
165}
166
167fn decode_data(cx: &mut Cx, value: &Value) -> Result<Vec<Value>> {
168    let list = value
169        .object()
170        .as_list()
171        .ok_or_else(|| Error::Eval("numbers/Tensor data field must be a list".to_owned()))?;
172    force_list_to_vec(cx, list, "numbers/Tensor data")
173}
174
175fn decode_domain(cx: &mut Cx, value: &Value) -> Result<Symbol> {
176    match value.object().as_expr(cx)? {
177        Expr::Symbol(symbol) => Ok(symbol),
178        _ => Err(Error::Eval(
179            "numbers/Tensor domain field must be a symbol".to_owned(),
180        )),
181    }
182}
183
184pub(crate) fn register_tensor_value_class(linker: &mut Linker<'_>) -> Result<()> {
185    let class = Arc::new(TensorValueClass::new());
186    let id = linker.class_value(
187        tensor_value_class_symbol(),
188        DefaultFactory
189            .opaque(class.clone())
190            .expect("tensor value class should be boxable"),
191    )?;
192    class.set_id(id);
193    Ok(())
194}
195
196fn install_tensor_value_citizen(linker: &mut Linker<'_>) -> Result<()> {
197    register_tensor_value_class(linker)
198}
199
200fn conformance_tensor_value_citizen(cx: &mut Cx) -> Result<()> {
201    let dtype = domains::i64();
202    let value = build_tensor_value(
203        cx,
204        vec![2],
205        Some(dtype.clone()),
206        vec![i64_cell("1")?, i64_cell("2")?],
207    )?;
208    sim_citizen::check_value_fixture_with_wrong_version(
209        cx,
210        value,
211        Some(vec![
212            Expr::Symbol(Symbol::new("v999")),
213            Expr::List(vec![int_expr("2")]),
214            Expr::List(vec![
215                Expr::Number(sim_kernel::NumberLiteral {
216                    domain: dtype.clone(),
217                    canonical: "1".to_owned(),
218                }),
219                Expr::Number(sim_kernel::NumberLiteral {
220                    domain: dtype.clone(),
221                    canonical: "2".to_owned(),
222                }),
223            ]),
224            Expr::Symbol(dtype),
225        ]),
226    )
227}
228
229fn i64_cell(canonical: &str) -> Result<Value> {
230    DefaultFactory.number_literal(domains::i64(), canonical.to_owned())
231}
232
233fn int_expr(canonical: &str) -> Expr {
234    Expr::Number(sim_kernel::NumberLiteral {
235        domain: Symbol::qualified("citizen", "int"),
236        canonical: canonical.to_owned(),
237    })
238}
239
240sim_citizen::inventory::submit! {
241    sim_citizen::CitizenInfo {
242        symbol: "numbers/Tensor",
243        version: 1,
244        crate_name: env!("CARGO_PKG_NAME"),
245        arity: 3,
246        install: install_tensor_value_citizen,
247        conformance: conformance_tensor_value_citizen,
248    }
249}