tract_hir/ops/array/shape.rs

1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct Shape {
6    pub dt: DatumType,
7}
8
9
10impl Expansion for Shape {
11    fn name(&self) -> StaticName {
12        "Shape".into()
13    }
14
15
16    fn rules<'r, 'p: 'r, 's: 'r>(
17        &'s self,
18        s: &mut Solver<'r>,
19        inputs: &'p [TensorProxy],
20        outputs: &'p [TensorProxy],
21    ) -> InferenceResult {
22        check_input_arity(inputs, 1)?;
23        check_output_arity(outputs, 1)?;
24        s.equals(&outputs[0].rank, 1)?;
25        s.equals(&outputs[0].shape[0], inputs[0].rank.bex().to_dim())?;
26        s.equals(&outputs[0].datum_type, self.dt.bex())?;
27        s.given(&inputs[0].shape, move |s, shape| {
28            let shape = tensor1(&shape);
29            if let Ok(shape) = shape.cast_to_dt(self.dt) {
30                s.equals(&outputs[0].value, shape.into_owned().into_arc_tensor())?;
31            }
32            Ok(())
33        })
34    }
35
36    fn wire(
37        &self,
38        prefix: &str,
39        model: &mut TypedModel,
40        inputs: &[OutletId],
41    ) -> TractResult<TVec<OutletId>> {
42        let shape = tensor1(&model.outlet_fact(inputs[0])?.shape.to_tvec());
43        let wire = model.add_const(format!("{prefix}.const"), shape)?;
44        model.wire_node(prefix, tract_core::ops::cast::cast(self.dt), &[wire])
45    }
46}