tract_hir/ops/array/
size.rs1use crate::infer::*;
2use crate::internal::*;
3
4#[derive(Debug, Clone, new, Hash)]
5pub struct Size {
6 pub dt: DatumType,
7}
8
9
10impl Expansion for Size {
11 fn name(&self) -> StaticName {
12 "Size".into()
13 }
14
15
16 fn rules<'r, 'p: 'r, 's: 'r>(
17 &'s self,
18 s: &mut Solver<'r>,
19 inputs: &'p [TensorProxy],
20 outputs: &'p [TensorProxy],
21 ) -> InferenceResult {
22 check_input_arity(inputs, 1)?;
23 check_output_arity(outputs, 1)?;
24 s.equals(&outputs[0].datum_type, self.dt)?;
25 s.equals(&outputs[0].rank, 0)?;
26 Ok(())
27 }
28
29 fn wire(
30 &self,
31 prefix: &str,
32 model: &mut TypedModel,
33 inputs: &[OutletId],
34 ) -> TractResult<TVec<OutletId>> {
35 let mut size = tensor0(model.outlet_fact(inputs[0])?.shape.iter().product::<TDim>());
36 if let Ok(s) = size.cast_to_dt(self.dt) {
37 size = s.into_owned();
38 }
39 let wire = model.add_const(prefix, size)?;
40 Ok(tvec!(wire))
41 }
42}