sim_lib_numbers_tensor/implementation/
citizen.rs1use std::sync::Arc;
5use std::sync::atomic::{AtomicU32, Ordering};
6
7use sim_citizen::{CitizenField, arity_error, decode_version};
8use sim_kernel::{
9 Args, Callable, Class, ClassId, ClassRef, Cx, DefaultFactory, Error, Expr, Factory, Linker,
10 Object, ReadConstructor, ReadConstructorRef, Result, ShapeRef, Symbol, TableRef, Value,
11 force_list_to_vec,
12};
13use sim_lib_numbers_core::domains;
14
15use super::{domain::number_domain, value::build_tensor_value};
16
17pub fn tensor_value_class_symbol() -> Symbol {
20 domains::tensor_value_class()
21}
22
23fn value_shape_symbol() -> Symbol {
24 sim_lib_numbers_core::value_shape_symbol(&number_domain())
25}
26
27struct TensorValueClass {
28 id: AtomicU32,
29}
30
31impl TensorValueClass {
32 fn new() -> Self {
33 Self {
34 id: AtomicU32::new(0),
35 }
36 }
37
38 fn set_id(&self, id: ClassId) {
39 self.id.store(id.0, Ordering::Relaxed);
40 }
41}
42
43impl Object for TensorValueClass {
44 fn display(&self, _cx: &mut Cx) -> Result<String> {
45 Ok(format!("#<class {}>", tensor_value_class_symbol()))
46 }
47
48 fn as_any(&self) -> &dyn std::any::Any {
49 self
50 }
51}
52
53impl sim_kernel::ObjectCompat for TensorValueClass {
54 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
55 if let Some(value) = cx
56 .registry()
57 .class_by_symbol(&Symbol::qualified("core", "Class"))
58 {
59 return Ok(value.clone());
60 }
61 DefaultFactory.class_stub(
62 sim_kernel::CORE_CLASS_CLASS_ID,
63 Symbol::qualified("core", "Class"),
64 )
65 }
66
67 fn as_expr(&self, _cx: &mut Cx) -> Result<Expr> {
68 Ok(Expr::Symbol(tensor_value_class_symbol()))
69 }
70
71 fn as_callable(&self) -> Option<&dyn Callable> {
72 Some(self)
73 }
74
75 fn as_class(&self) -> Option<&dyn Class> {
76 Some(self)
77 }
78
79 fn as_read_constructor(&self) -> Option<&dyn ReadConstructor> {
80 Some(self)
81 }
82}
83
84impl Callable for TensorValueClass {
85 fn call(&self, cx: &mut Cx, args: Args) -> Result<Value> {
86 let values = args.into_vec();
87 let [version, shape, data, domain] = values.as_slice() else {
88 return Err(arity_error(tensor_value_class_symbol(), 4, values.len()));
89 };
90 decode_version(cx, version.clone(), 1, tensor_value_class_symbol())?;
91 let shape = Vec::<usize>::decode_field_value(cx, shape.clone(), "shape")?;
92 let data = decode_data(cx, data)?;
93 let domain = decode_domain(cx, domain)?;
94 if domain == number_domain() {
95 return Err(Error::Eval(
96 "numbers/Tensor domain field must name a scalar number domain".to_owned(),
97 ));
98 }
99 build_tensor_value(cx, shape, Some(domain), data)
100 }
101}
102
103impl Class for TensorValueClass {
104 fn id(&self) -> ClassId {
105 ClassId(self.id.load(Ordering::Relaxed))
106 }
107
108 fn symbol(&self) -> Symbol {
109 tensor_value_class_symbol()
110 }
111
112 fn constructor_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
113 cx.factory().nil()
114 }
115
116 fn instance_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
117 Ok(cx
118 .registry()
119 .shape_by_symbol(&value_shape_symbol())
120 .cloned()
121 .unwrap_or(cx.factory().symbol(value_shape_symbol())?))
122 }
123
124 fn read_constructor(&self, cx: &mut Cx) -> Result<Option<ReadConstructorRef>> {
125 Ok(cx
126 .registry()
127 .class_by_symbol(&tensor_value_class_symbol())
128 .cloned())
129 }
130
131 fn members(&self, cx: &mut Cx) -> Result<TableRef> {
132 cx.factory().table(vec![
133 (
134 Symbol::new("version"),
135 cx.factory()
136 .number_literal(Symbol::qualified("citizen", "int"), "1".to_owned())?,
137 ),
138 (
139 Symbol::new("fields"),
140 cx.factory().list(vec![
141 cx.factory().symbol(Symbol::new("shape"))?,
142 cx.factory().symbol(Symbol::new("data"))?,
143 cx.factory().symbol(Symbol::new("domain"))?,
144 ])?,
145 ),
146 ])
147 }
148}
149
150impl ReadConstructor for TensorValueClass {
151 fn symbol(&self) -> Symbol {
152 tensor_value_class_symbol()
153 }
154
155 fn args_shape(&self, cx: &mut Cx) -> Result<ShapeRef> {
156 cx.factory().nil()
157 }
158
159 fn construct_read(&self, cx: &mut Cx, args: Vec<Value>) -> Result<Value> {
160 if args.len() != 4 {
161 return Err(arity_error(tensor_value_class_symbol(), 4, args.len()));
162 }
163 self.call(cx, Args::new(args))
164 }
165}
166
167fn decode_data(cx: &mut Cx, value: &Value) -> Result<Vec<Value>> {
168 let list = value
169 .object()
170 .as_list()
171 .ok_or_else(|| Error::Eval("numbers/Tensor data field must be a list".to_owned()))?;
172 force_list_to_vec(cx, list, "numbers/Tensor data")
173}
174
175fn decode_domain(cx: &mut Cx, value: &Value) -> Result<Symbol> {
176 match value.object().as_expr(cx)? {
177 Expr::Symbol(symbol) => Ok(symbol),
178 _ => Err(Error::Eval(
179 "numbers/Tensor domain field must be a symbol".to_owned(),
180 )),
181 }
182}
183
184pub(crate) fn register_tensor_value_class(linker: &mut Linker<'_>) -> Result<()> {
185 let class = Arc::new(TensorValueClass::new());
186 let id = linker.class_value(
187 tensor_value_class_symbol(),
188 DefaultFactory
189 .opaque(class.clone())
190 .expect("tensor value class should be boxable"),
191 )?;
192 class.set_id(id);
193 Ok(())
194}
195
196fn install_tensor_value_citizen(linker: &mut Linker<'_>) -> Result<()> {
197 register_tensor_value_class(linker)
198}
199
200fn conformance_tensor_value_citizen(cx: &mut Cx) -> Result<()> {
201 let dtype = domains::i64();
202 let value = build_tensor_value(
203 cx,
204 vec![2],
205 Some(dtype.clone()),
206 vec![i64_cell("1")?, i64_cell("2")?],
207 )?;
208 sim_citizen::check_value_fixture_with_wrong_version(
209 cx,
210 value,
211 Some(vec![
212 Expr::Symbol(Symbol::new("v999")),
213 Expr::List(vec![int_expr("2")]),
214 Expr::List(vec![
215 Expr::Number(sim_kernel::NumberLiteral {
216 domain: dtype.clone(),
217 canonical: "1".to_owned(),
218 }),
219 Expr::Number(sim_kernel::NumberLiteral {
220 domain: dtype.clone(),
221 canonical: "2".to_owned(),
222 }),
223 ]),
224 Expr::Symbol(dtype),
225 ]),
226 )
227}
228
229fn i64_cell(canonical: &str) -> Result<Value> {
230 DefaultFactory.number_literal(domains::i64(), canonical.to_owned())
231}
232
233fn int_expr(canonical: &str) -> Expr {
234 Expr::Number(sim_kernel::NumberLiteral {
235 domain: Symbol::qualified("citizen", "int"),
236 canonical: canonical.to_owned(),
237 })
238}
239
240sim_citizen::inventory::submit! {
241 sim_citizen::CitizenInfo {
242 symbol: "numbers/Tensor",
243 version: 1,
244 crate_name: env!("CARGO_PKG_NAME"),
245 arity: 3,
246 install: install_tensor_value_citizen,
247 conformance: conformance_tensor_value_citizen,
248 }
249}