Skip to main content

sim_lib_numbers_tensor/implementation/
value.rs

1//! The uniform `Tensor` value type: its shape, dtype, and cell storage, with
2//! indexing, construction, and number-value behavior backing the tensor domain.
3
4use std::cmp::Ordering;
5use std::collections::{BTreeMap, BinaryHeap};
6use std::sync::Arc;
7
8use sim_kernel::{
9    ClassRef, Cx, DefaultFactory, Error, Expr, Factory, NumberValue, Object, ObjectCompat,
10    ObjectEncode, ObjectEncoding, Result, Symbol, Value,
11};
12
13use super::citizen::tensor_value_class_symbol;
14use super::domain::number_domain;
15
16/// The uniform tensor value: an n-dimensional array of scalar number cells.
17///
18/// A tensor is row-major (last axis varies fastest) and homogeneous: every cell
19/// shares the [`dtype`](Tensor::dtype) number domain. An empty [`shape`](Tensor::shape)
20/// denotes a rank-0 scalar holding a single cell. Tensors are the value backing
21/// the `numbers/tensor` domain and are constructed through
22/// [`build_tensor_value`] rather than parsed from literals.
23#[derive(Clone)]
24pub struct Tensor {
25    /// Length of each axis, outermost first. Empty for a rank-0 scalar.
26    pub shape: Vec<usize>,
27    /// The shared scalar number domain of every cell (for example
28    /// `numbers/i64` or `numbers/f64`).
29    pub dtype: Symbol,
30    /// Row-major cell storage; its length equals the product of `shape`
31    /// (one for a scalar).
32    pub data: Vec<Value>,
33}
34
35impl Tensor {
36    /// The number of axes, i.e. the length of [`shape`](Tensor::shape). Zero
37    /// for a scalar.
38    pub fn rank(&self) -> usize {
39        self.shape.len()
40    }
41
42    /// Computes the row-major flat offset into [`data`](Tensor::data) for a
43    /// multi-dimensional `indices` coordinate against `shape`.
44    ///
45    /// Returns an error if the index rank does not match `shape` or any
46    /// component is out of bounds.
47    ///
48    /// # Examples
49    ///
50    /// ```
51    /// use sim_lib_numbers_tensor::Tensor;
52    ///
53    /// // Row-major 2x3 tensor: element (1, 2) is at flat offset 5.
54    /// assert_eq!(Tensor::flat_offset(&[2, 3], &[1, 2]).unwrap(), 5);
55    /// assert_eq!(Tensor::flat_offset(&[2, 3], &[0, 0]).unwrap(), 0);
56    /// // Out-of-bounds and rank-mismatched indices are rejected.
57    /// assert!(Tensor::flat_offset(&[2, 3], &[2, 0]).is_err());
58    /// assert!(Tensor::flat_offset(&[2, 3], &[0]).is_err());
59    /// ```
60    pub fn flat_offset(shape: &[usize], indices: &[usize]) -> Result<usize> {
61        if shape.len() != indices.len() {
62            return Err(Error::Eval("tensor index rank mismatch".to_owned()));
63        }
64        let mut stride = 1usize;
65        let mut offset = 0usize;
66        for (dim, index) in shape.iter().rev().zip(indices.iter().rev()) {
67            if *index >= *dim {
68                return Err(Error::Eval("tensor index was out of bounds".to_owned()));
69            }
70            offset += index * stride;
71            stride = stride.saturating_mul(*dim);
72        }
73        Ok(offset)
74    }
75
76    /// Enumerates every multi-dimensional coordinate of `shape` in row-major
77    /// order. An empty shape yields a single empty coordinate (the scalar cell).
78    pub fn coordinates(shape: &[usize]) -> Vec<Vec<usize>> {
79        if shape.is_empty() {
80            return vec![Vec::new()];
81        }
82        let mut out = Vec::new();
83        let mut coord = vec![0usize; shape.len()];
84        loop {
85            out.push(coord.clone());
86            let mut axis = shape.len();
87            while axis > 0 {
88                axis -= 1;
89                coord[axis] += 1;
90                if coord[axis] < shape[axis] {
91                    break;
92                }
93                coord[axis] = 0;
94                if axis == 0 {
95                    return out;
96                }
97            }
98        }
99    }
100}
101
102impl Object for Tensor {
103    fn display(&self, cx: &mut Cx) -> Result<String> {
104        match self.as_expr(cx)? {
105            Expr::Call { .. } => Ok(format!("{}<{:?}>", tensor_display_name(), self.shape)),
106            expr => Ok(format!("{expr:?}")),
107        }
108    }
109
110    fn as_any(&self) -> &dyn std::any::Any {
111        self
112    }
113}
114
115impl sim_kernel::ObjectCompat for Tensor {
116    fn class(&self, cx: &mut Cx) -> Result<ClassRef> {
117        if let Some(value) = cx.registry().class_by_symbol(&tensor_value_class_symbol()) {
118            return Ok(value.clone());
119        }
120        if let Some(value) = cx
121            .registry()
122            .class_by_symbol(&Symbol::qualified("core", "Number"))
123        {
124            return Ok(value.clone());
125        }
126        DefaultFactory.class_stub(
127            sim_kernel::CORE_NUMBER_CLASS_ID,
128            Symbol::qualified("core", "Number"),
129        )
130    }
131    fn as_expr(&self, cx: &mut Cx) -> Result<Expr> {
132        match self.rank() {
133            0 => Ok(Expr::Call {
134                operator: Box::new(Expr::Symbol(Symbol::new("scalar"))),
135                args: vec![self.data[0].object().as_expr(cx)?],
136            }),
137            1 => Ok(Expr::Vector(exprs(cx, &self.data)?)),
138            2 => {
139                let width = self.shape[1];
140                let rows = self
141                    .data
142                    .chunks(width)
143                    .map(|row| exprs(cx, row).map(Expr::Vector))
144                    .collect::<Result<Vec<_>>>()?;
145                Ok(Expr::Vector(rows))
146            }
147            _ => Ok(Expr::Call {
148                operator: Box::new(Expr::Symbol(Symbol::new("tensor"))),
149                args: vec![
150                    Expr::Vector(
151                        self.shape
152                            .iter()
153                            .map(|dim| Expr::String(dim.to_string()))
154                            .collect(),
155                    ),
156                    Expr::Symbol(self.dtype.clone()),
157                    Expr::Vector(exprs(cx, &self.data)?),
158                ],
159            }),
160        }
161    }
162    fn as_table(&self, cx: &mut Cx) -> Result<Value> {
163        let shape = cx.factory().list(
164            self.shape
165                .iter()
166                .map(|dim| cx.factory().string(dim.to_string()))
167                .collect::<Result<Vec<_>>>()?,
168        )?;
169        let data = cx.factory().list(self.data.clone())?;
170        cx.factory().table(vec![
171            (
172                Symbol::new("kind"),
173                cx.factory().string("tensor".to_owned())?,
174            ),
175            (Symbol::new("shape"), shape),
176            (
177                Symbol::new("dtype"),
178                cx.factory().symbol(self.dtype.clone())?,
179            ),
180            (Symbol::new("data"), data),
181        ])
182    }
183    fn as_number_value(&self) -> Option<&dyn NumberValue> {
184        Some(self)
185    }
186
187    fn as_object_encoder(&self) -> Option<&dyn ObjectEncode> {
188        Some(self)
189    }
190}
191
192impl NumberValue for Tensor {
193    fn number_domain(&self, _cx: &mut Cx) -> Result<Symbol> {
194        Ok(number_domain())
195    }
196}
197
198impl ObjectEncode for Tensor {
199    fn object_encoding(&self, cx: &mut Cx) -> Result<ObjectEncoding> {
200        Ok(ObjectEncoding::Constructor {
201            class: tensor_value_class_symbol(),
202            args: vec![
203                Expr::Symbol(Symbol::new("v1")),
204                Expr::List(
205                    self.shape
206                        .iter()
207                        .map(|dim| {
208                            Expr::Number(sim_kernel::NumberLiteral {
209                                domain: Symbol::qualified("citizen", "int"),
210                                canonical: dim.to_string(),
211                            })
212                        })
213                        .collect(),
214                ),
215                Expr::List(exprs(cx, &self.data)?),
216                Expr::Symbol(self.dtype.clone()),
217            ],
218        })
219    }
220}
221
222impl sim_citizen::Citizen for Tensor {
223    fn citizen_symbol() -> Symbol {
224        tensor_value_class_symbol()
225    }
226
227    fn citizen_version() -> u32 {
228        1
229    }
230
231    fn citizen_arity() -> usize {
232        3
233    }
234
235    fn citizen_fields() -> &'static [&'static str] {
236        &["shape", "data", "domain"]
237    }
238}
239
240/// Builds a tensor [`Value`] of the given `shape` from row-major `data` cells.
241///
242/// The cell count must equal the product of `shape` (one for an empty, scalar
243/// shape). Every cell must be a scalar number value (not a nested tensor). When
244/// `dtype_hint` is `Some`, all cells must promote to that domain; otherwise the
245/// element domain is chosen as the cheapest join of the cell domains. Returns an
246/// error on a cell-count mismatch, a non-scalar cell, or an impossible dtype.
247pub fn build_tensor_value(
248    cx: &mut Cx,
249    shape: Vec<usize>,
250    dtype_hint: Option<Symbol>,
251    data: Vec<Value>,
252) -> Result<Value> {
253    let expected = checked_element_count(&shape)?;
254    if data.len() != expected {
255        return Err(Error::Eval(format!(
256            "tensor shape {:?} expects {expected} cells, found {}",
257            shape,
258            data.len()
259        )));
260    }
261    validate_cells(cx, &data)?;
262    let dtype = choose_dtype(cx, dtype_hint, &data)?;
263    cx.factory().opaque(Arc::new(Tensor { shape, dtype, data }))
264}
265
266/// Builds a rank-0 scalar tensor wrapping a single scalar number `value`.
267pub fn build_scalar_tensor_value(cx: &mut Cx, value: Value) -> Result<Value> {
268    build_tensor_value(cx, Vec::new(), None, vec![value])
269}
270
271/// Borrows the [`Tensor`] backing a value, or `None` if it is not a tensor.
272pub fn tensor_value_ref(value: &Value) -> Option<&Tensor> {
273    value.object().downcast_ref::<Tensor>()
274}
275
276/// The shared element number domain (dtype) of a tensor's cells.
277pub fn tensor_dtype(tensor: &Tensor) -> &Symbol {
278    &tensor.dtype
279}
280
281/// Clones a tensor's row-major cell values as a flat vector.
282pub fn flatten_tensor_scalar_cells(tensor: &Tensor) -> Vec<Value> {
283    tensor.data.clone()
284}
285
/// Name used when rendering tensors as text, for example `tensor<[2, 3, 4]>`.
pub fn tensor_display_name() -> &'static str {
    const DISPLAY_NAME: &str = "tensor";
    DISPLAY_NAME
}
289
290fn exprs(cx: &mut Cx, data: &[Value]) -> Result<Vec<Expr>> {
291    data.iter()
292        .map(|value| value.object().as_expr(cx))
293        .collect()
294}
295
296use crate::spec::checked_element_count;
297
298fn validate_cells(cx: &mut Cx, data: &[Value]) -> Result<()> {
299    for cell in data {
300        let Some(number) = cx.number_value_ref(cell.clone())? else {
301            return Err(Error::Eval(
302                "tensor cells must all be scalar number values".to_owned(),
303            ));
304        };
305        if number.domain == number_domain() {
306            return Err(Error::Eval(
307                "tensor cells must be scalar numbers, not nested tensors".to_owned(),
308            ));
309        }
310    }
311    Ok(())
312}
313
314fn choose_dtype(cx: &mut Cx, dtype_hint: Option<Symbol>, data: &[Value]) -> Result<Symbol> {
315    let domains = data
316        .iter()
317        .map(|value| {
318            cx.number_value_ref(value.clone())?
319                .map(|number| number.domain)
320                .ok_or_else(|| {
321                    Error::Eval("tensor cells must all be scalar number values".to_owned())
322                })
323        })
324        .collect::<Result<Vec<_>>>()?;
325    let Some(first) = domains.first() else {
326        return Err(Error::Eval("tensor requires at least one cell".to_owned()));
327    };
328    if let Some(dtype) = dtype_hint {
329        if domains
330            .iter()
331            .all(|domain| promotion_cost(cx, domain, &dtype).is_some())
332        {
333            return Ok(dtype);
334        }
335        return Err(Error::Eval(format!(
336            "tensor dtype {dtype} is not a valid join for cell domains {domains:?}"
337        )));
338    }
339    let candidates = cx
340        .registry()
341        .number_domains()
342        .keys()
343        .filter(|symbol| **symbol != number_domain())
344        .cloned()
345        .collect::<Vec<_>>();
346    let mut best = None::<(u32, Symbol)>;
347    for candidate in candidates {
348        let mut total = 0u32;
349        let mut valid = true;
350        for domain in &domains {
351            let Some(cost) = promotion_cost(cx, domain, &candidate) else {
352                valid = false;
353                break;
354            };
355            total += cost;
356        }
357        if !valid {
358            continue;
359        }
360        match &best {
361            Some((best_cost, best_symbol))
362                if total > *best_cost || (total == *best_cost && candidate >= *best_symbol) => {}
363            _ => best = Some((total, candidate)),
364        }
365    }
366    best.map(|(_, symbol)| symbol)
367        .ok_or_else(|| {
368            Error::Eval(format!(
369                "no join domain exists for tensor cells {domains:?}"
370            ))
371        })
372        .or_else(|_| Ok(first.clone()))
373}
374
375fn promotion_cost(cx: &Cx, from: &Symbol, to: &Symbol) -> Option<u32> {
376    if from == to {
377        return Some(0);
378    }
379
380    #[derive(Clone, Eq, PartialEq)]
381    struct State {
382        cost: u32,
383        symbol: Symbol,
384    }
385
386    impl Ord for State {
387        fn cmp(&self, other: &Self) -> Ordering {
388            other
389                .cost
390                .cmp(&self.cost)
391                .then_with(|| other.symbol.cmp(&self.symbol))
392        }
393    }
394
395    impl PartialOrd for State {
396        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
397            Some(self.cmp(other))
398        }
399    }
400
401    let mut best = BTreeMap::<Symbol, u32>::new();
402    let mut heap = BinaryHeap::new();
403    best.insert(from.clone(), 0);
404    heap.push(State {
405        cost: 0,
406        symbol: from.clone(),
407    });
408
409    while let Some(State { cost, symbol }) = heap.pop() {
410        if &symbol == to {
411            return Some(cost);
412        }
413        if best.get(&symbol).copied().unwrap_or(u32::MAX) < cost {
414            continue;
415        }
416        for rule in cx
417            .registry()
418            .value_promotion_rules()
419            .iter()
420            .filter(|rule| rule.from_domain == symbol)
421        {
422            let next = cost + rule.cost as u32;
423            let entry = best.entry(rule.to_domain.clone()).or_insert(u32::MAX);
424            if next < *entry {
425                *entry = next;
426                heap.push(State {
427                    cost: next,
428                    symbol: rule.to_domain.clone(),
429                });
430            }
431        }
432    }
433    None
434}