tract_hir/ops/array/
size.rs

1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct Size {
6    pub dt: DatumType,
7}
8
9
10impl Expansion for Size {
11    fn name(&self) -> StaticName {
12        "Size".into()
13    }
14
15
16    fn rules<'r, 'p: 'r, 's: 'r>(
17        &'s self,
18        s: &mut Solver<'r>,
19        inputs: &'p [TensorProxy],
20        outputs: &'p [TensorProxy],
21    ) -> InferenceResult {
22        check_input_arity(inputs, 1)?;
23        check_output_arity(outputs, 1)?;
24        s.equals(&outputs[0].datum_type, self.dt)?;
25        s.equals(&outputs[0].rank, 0)?;
26        Ok(())
27    }
28
29    fn wire(
30        &self,
31        prefix: &str,
32        model: &mut TypedModel,
33        inputs: &[OutletId],
34    ) -> TractResult<TVec<OutletId>> {
35        let mut size = tensor0(model.outlet_fact(inputs[0])?.shape.iter().product::<TDim>());
36        if let Ok(s) = size.cast_to_dt(self.dt) {
37            size = s.into_owned();
38        }
39        let wire = model.add_const(prefix, size)?;
40        Ok(tvec!(wire))
41    }
42}