sim_lib_numbers_tensor/implementation/
value.rs1use std::cmp::Ordering;
5use std::collections::{BTreeMap, BinaryHeap};
6use std::sync::Arc;
7
8use sim_kernel::{
9 ClassRef, Cx, DefaultFactory, Error, Expr, Factory, NumberValue, Object, ObjectCompat,
10 ObjectEncode, ObjectEncoding, Result, Symbol, Value,
11};
12
13use super::citizen::tensor_value_class_symbol;
14use super::domain::number_domain;
15
/// A dense tensor whose cells are scalar number values, stored row-major.
///
/// The fields are public, so the following only hold for values produced by
/// `build_tensor_value`:
/// - `data.len()` equals the element count of `shape` (as computed by
///   `checked_element_count`);
/// - every cell is a scalar number value, never a nested tensor;
/// - `dtype` is the element domain picked by `choose_dtype` at build time.
#[derive(Clone)]
pub struct Tensor {
    /// Extent of each axis, outermost first; empty for a scalar (rank 0).
    pub shape: Vec<usize>,
    /// Element domain symbol recorded for the cells.
    pub dtype: Symbol,
    /// Cells in row-major order (the last axis varies fastest).
    pub data: Vec<Value>,
}
34
35impl Tensor {
36 pub fn rank(&self) -> usize {
39 self.shape.len()
40 }
41
42 pub fn flat_offset(shape: &[usize], indices: &[usize]) -> Result<usize> {
61 if shape.len() != indices.len() {
62 return Err(Error::Eval("tensor index rank mismatch".to_owned()));
63 }
64 let mut stride = 1usize;
65 let mut offset = 0usize;
66 for (dim, index) in shape.iter().rev().zip(indices.iter().rev()) {
67 if *index >= *dim {
68 return Err(Error::Eval("tensor index was out of bounds".to_owned()));
69 }
70 offset += index * stride;
71 stride = stride.saturating_mul(*dim);
72 }
73 Ok(offset)
74 }
75
76 pub fn coordinates(shape: &[usize]) -> Vec<Vec<usize>> {
79 if shape.is_empty() {
80 return vec![Vec::new()];
81 }
82 let mut out = Vec::new();
83 let mut coord = vec![0usize; shape.len()];
84 loop {
85 out.push(coord.clone());
86 let mut axis = shape.len();
87 while axis > 0 {
88 axis -= 1;
89 coord[axis] += 1;
90 if coord[axis] < shape[axis] {
91 break;
92 }
93 coord[axis] = 0;
94 if axis == 0 {
95 return out;
96 }
97 }
98 }
99 }
100}
101
102impl Object for Tensor {
103 fn display(&self, cx: &mut Cx) -> Result<String> {
104 match self.as_expr(cx)? {
105 Expr::Call { .. } => Ok(format!("{}<{:?}>", tensor_display_name(), self.shape)),
106 expr => Ok(format!("{expr:?}")),
107 }
108 }
109
110 fn as_any(&self) -> &dyn std::any::Any {
111 self
112 }
113}
114
115impl sim_kernel::ObjectCompat for Tensor {
116 fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
117 if let Some(value) = cx.registry().class_by_symbol(&tensor_value_class_symbol()) {
118 return Ok(value.clone());
119 }
120 if let Some(value) = cx
121 .registry()
122 .class_by_symbol(&Symbol::qualified("core", "Number"))
123 {
124 return Ok(value.clone());
125 }
126 DefaultFactory.class_stub(
127 sim_kernel::CORE_NUMBER_CLASS_ID,
128 Symbol::qualified("core", "Number"),
129 )
130 }
131 fn as_expr(&self, cx: &mut Cx) -> Result<Expr> {
132 match self.rank() {
133 0 => Ok(Expr::Call {
134 operator: Box::new(Expr::Symbol(Symbol::new("scalar"))),
135 args: vec![self.data[0].object().as_expr(cx)?],
136 }),
137 1 => Ok(Expr::Vector(exprs(cx, &self.data)?)),
138 2 => {
139 let width = self.shape[1];
140 let rows = self
141 .data
142 .chunks(width)
143 .map(|row| exprs(cx, row).map(Expr::Vector))
144 .collect::<Result<Vec<_>>>()?;
145 Ok(Expr::Vector(rows))
146 }
147 _ => Ok(Expr::Call {
148 operator: Box::new(Expr::Symbol(Symbol::new("tensor"))),
149 args: vec![
150 Expr::Vector(
151 self.shape
152 .iter()
153 .map(|dim| Expr::String(dim.to_string()))
154 .collect(),
155 ),
156 Expr::Symbol(self.dtype.clone()),
157 Expr::Vector(exprs(cx, &self.data)?),
158 ],
159 }),
160 }
161 }
162 fn as_table(&self, cx: &mut Cx) -> Result<Value> {
163 let shape = cx.factory().list(
164 self.shape
165 .iter()
166 .map(|dim| cx.factory().string(dim.to_string()))
167 .collect::<Result<Vec<_>>>()?,
168 )?;
169 let data = cx.factory().list(self.data.clone())?;
170 cx.factory().table(vec![
171 (
172 Symbol::new("kind"),
173 cx.factory().string("tensor".to_owned())?,
174 ),
175 (Symbol::new("shape"), shape),
176 (
177 Symbol::new("dtype"),
178 cx.factory().symbol(self.dtype.clone())?,
179 ),
180 (Symbol::new("data"), data),
181 ])
182 }
183 fn as_number_value(&self) -> Option<&dyn NumberValue> {
184 Some(self)
185 }
186
187 fn as_object_encoder(&self) -> Option<&dyn ObjectEncode> {
188 Some(self)
189 }
190}
191
192impl NumberValue for Tensor {
193 fn number_domain(&self, _cx: &mut Cx) -> Result<Symbol> {
194 Ok(number_domain())
195 }
196}
197
198impl ObjectEncode for Tensor {
199 fn object_encoding(&self, cx: &mut Cx) -> Result<ObjectEncoding> {
200 Ok(ObjectEncoding::Constructor {
201 class: tensor_value_class_symbol(),
202 args: vec![
203 Expr::Symbol(Symbol::new("v1")),
204 Expr::List(
205 self.shape
206 .iter()
207 .map(|dim| {
208 Expr::Number(sim_kernel::NumberLiteral {
209 domain: Symbol::qualified("citizen", "int"),
210 canonical: dim.to_string(),
211 })
212 })
213 .collect(),
214 ),
215 Expr::List(exprs(cx, &self.data)?),
216 Expr::Symbol(self.dtype.clone()),
217 ],
218 })
219 }
220}
221
222impl sim_citizen::Citizen for Tensor {
223 fn citizen_symbol() -> Symbol {
224 tensor_value_class_symbol()
225 }
226
227 fn citizen_version() -> u32 {
228 1
229 }
230
231 fn citizen_arity() -> usize {
232 3
233 }
234
235 fn citizen_fields() -> &'static [&'static str] {
236 &["shape", "data", "domain"]
237 }
238}
239
240pub fn build_tensor_value(
248 cx: &mut Cx,
249 shape: Vec<usize>,
250 dtype_hint: Option<Symbol>,
251 data: Vec<Value>,
252) -> Result<Value> {
253 let expected = checked_element_count(&shape)?;
254 if data.len() != expected {
255 return Err(Error::Eval(format!(
256 "tensor shape {:?} expects {expected} cells, found {}",
257 shape,
258 data.len()
259 )));
260 }
261 validate_cells(cx, &data)?;
262 let dtype = choose_dtype(cx, dtype_hint, &data)?;
263 cx.factory().opaque(Arc::new(Tensor { shape, dtype, data }))
264}
265
266pub fn build_scalar_tensor_value(cx: &mut Cx, value: Value) -> Result<Value> {
268 build_tensor_value(cx, Vec::new(), None, vec![value])
269}
270
271pub fn tensor_value_ref(value: &Value) -> Option<&Tensor> {
273 value.object().downcast_ref::<Tensor>()
274}
275
276pub fn tensor_dtype(tensor: &Tensor) -> &Symbol {
278 &tensor.dtype
279}
280
281pub fn flatten_tensor_scalar_cells(tensor: &Tensor) -> Vec<Value> {
283 tensor.data.clone()
284}
285
/// Name used as the prefix of the compact `name<shape>` display form.
pub fn tensor_display_name() -> &'static str {
    const NAME: &str = "tensor";
    NAME
}
289
290fn exprs(cx: &mut Cx, data: &[Value]) -> Result<Vec<Expr>> {
291 data.iter()
292 .map(|value| value.object().as_expr(cx))
293 .collect()
294}
295
296use crate::spec::checked_element_count;
297
298fn validate_cells(cx: &mut Cx, data: &[Value]) -> Result<()> {
299 for cell in data {
300 let Some(number) = cx.number_value_ref(cell.clone())? else {
301 return Err(Error::Eval(
302 "tensor cells must all be scalar number values".to_owned(),
303 ));
304 };
305 if number.domain == number_domain() {
306 return Err(Error::Eval(
307 "tensor cells must be scalar numbers, not nested tensors".to_owned(),
308 ));
309 }
310 }
311 Ok(())
312}
313
314fn choose_dtype(cx: &mut Cx, dtype_hint: Option<Symbol>, data: &[Value]) -> Result<Symbol> {
315 let domains = data
316 .iter()
317 .map(|value| {
318 cx.number_value_ref(value.clone())?
319 .map(|number| number.domain)
320 .ok_or_else(|| {
321 Error::Eval("tensor cells must all be scalar number values".to_owned())
322 })
323 })
324 .collect::<Result<Vec<_>>>()?;
325 let Some(first) = domains.first() else {
326 return Err(Error::Eval("tensor requires at least one cell".to_owned()));
327 };
328 if let Some(dtype) = dtype_hint {
329 if domains
330 .iter()
331 .all(|domain| promotion_cost(cx, domain, &dtype).is_some())
332 {
333 return Ok(dtype);
334 }
335 return Err(Error::Eval(format!(
336 "tensor dtype {dtype} is not a valid join for cell domains {domains:?}"
337 )));
338 }
339 let candidates = cx
340 .registry()
341 .number_domains()
342 .keys()
343 .filter(|symbol| **symbol != number_domain())
344 .cloned()
345 .collect::<Vec<_>>();
346 let mut best = None::<(u32, Symbol)>;
347 for candidate in candidates {
348 let mut total = 0u32;
349 let mut valid = true;
350 for domain in &domains {
351 let Some(cost) = promotion_cost(cx, domain, &candidate) else {
352 valid = false;
353 break;
354 };
355 total += cost;
356 }
357 if !valid {
358 continue;
359 }
360 match &best {
361 Some((best_cost, best_symbol))
362 if total > *best_cost || (total == *best_cost && candidate >= *best_symbol) => {}
363 _ => best = Some((total, candidate)),
364 }
365 }
366 best.map(|(_, symbol)| symbol)
367 .ok_or_else(|| {
368 Error::Eval(format!(
369 "no join domain exists for tensor cells {domains:?}"
370 ))
371 })
372 .or_else(|_| Ok(first.clone()))
373}
374
375fn promotion_cost(cx: &Cx, from: &Symbol, to: &Symbol) -> Option<u32> {
376 if from == to {
377 return Some(0);
378 }
379
380 #[derive(Clone, Eq, PartialEq)]
381 struct State {
382 cost: u32,
383 symbol: Symbol,
384 }
385
386 impl Ord for State {
387 fn cmp(&self, other: &Self) -> Ordering {
388 other
389 .cost
390 .cmp(&self.cost)
391 .then_with(|| other.symbol.cmp(&self.symbol))
392 }
393 }
394
395 impl PartialOrd for State {
396 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
397 Some(self.cmp(other))
398 }
399 }
400
401 let mut best = BTreeMap::<Symbol, u32>::new();
402 let mut heap = BinaryHeap::new();
403 best.insert(from.clone(), 0);
404 heap.push(State {
405 cost: 0,
406 symbol: from.clone(),
407 });
408
409 while let Some(State { cost, symbol }) = heap.pop() {
410 if &symbol == to {
411 return Some(cost);
412 }
413 if best.get(&symbol).copied().unwrap_or(u32::MAX) < cost {
414 continue;
415 }
416 for rule in cx
417 .registry()
418 .value_promotion_rules()
419 .iter()
420 .filter(|rule| rule.from_domain == symbol)
421 {
422 let next = cost + rule.cost as u32;
423 let entry = best.entry(rule.to_domain.clone()).or_insert(u32::MAX);
424 if next < *entry {
425 *entry = next;
426 heap.push(State {
427 cost: next,
428 symbol: rule.to_domain.clone(),
429 });
430 }
431 }
432 }
433 None
434}