1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
use crate::internal::*;

/// Operator producing a tensor of shape `shape` in which every element is `scalar`.
///
/// `shape` is expressed as `TDim`s, so some dimensions may only become concrete at
/// run-time; in that case evaluation goes through `ConstantOfShapeState`, which
/// resolves them against the session's symbols.
#[derive(Debug, Clone, new, Hash)]
pub struct ConstantOfShape {
    /// Output shape; dimensions may need symbol resolution before use.
    shape: TVec<TDim>,
    /// Fill value. `output_facts` rejects it unless it is rank 0.
    scalar: Arc<Tensor>,
}

impl_dyn_hash!(ConstantOfShape);

impl ConstantOfShape {
    /// Builds the output tensor for an already-concrete `shape` by broadcasting
    /// `self.scalar` to it.
    ///
    /// Shared by `declutter`, the stateless `eval` and `ConstantOfShapeState::eval`
    /// so the three evaluation paths stay identical.
    fn materialize(&self, shape: &[usize]) -> TractResult<Arc<Tensor>> {
        Ok(self.scalar.broadcast_scalar_to_shape(shape)?.into_arc_tensor())
    }

    /// Converts `self.shape` to concrete sizes without any symbol resolution.
    ///
    /// Errors if any dimension cannot be turned into a `usize` as-is.
    fn concrete_shape(&self) -> TractResult<TVec<usize>> {
        self.shape.iter().map(|d| d.to_usize()).collect()
    }
}

impl Op for ConstantOfShape {
    fn name(&self) -> Cow<str> {
        "ConstantOfShape".into()
    }

    op_core!();
    op_as_typed_op!();
}

impl TypedOp for ConstantOfShape {
    /// Declares a single output with `scalar`'s datum type and shape `shape`,
    /// marked as uniform with value `scalar`.
    ///
    /// Errors if `scalar` is not rank 0.
    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if self.scalar.rank() > 0 {
            bail!("ConstantOfShape attribute must be a scalar, {:?}", self.scalar)
        }
        let mut fact = TypedFact::dt_shape(self.scalar.datum_type(), &self.shape);
        fact.uniform = Some(self.scalar.clone());
        Ok(tvec!(fact))
    }

    /// Replaces this node with a `Const` holding the materialized tensor when every
    /// dimension of `shape` is already concrete; otherwise returns `Ok(None)`.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // A non-concrete shape is not an error here: the op simply stays in the graph
        // and is evaluated later, once the session can resolve its dimensions.
        let shape = match self.concrete_shape() {
            Ok(shape) => shape,
            Err(_) => return Ok(None),
        };
        let tensor = self.materialize(&shape)?;
        Ok(Some(TypedModelPatch::replace_single_op(
            model,
            node,
            &[],
            crate::ops::konst::Const::new(tensor),
        )?))
    }

    as_op!();
}

impl EvalOp for ConstantOfShape {
    /// Stateless evaluation: ignores its inputs and errors if any dimension of
    /// `shape` is not concrete (see `is_stateless`).
    fn eval(&self, _inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let shape = self.concrete_shape()?;
        Ok(tvec!(self.materialize(&shape)?))
    }

    /// The op needs no session state exactly when every dimension is already concrete.
    fn is_stateless(&self) -> bool {
        self.shape.iter().all(|d| d.to_usize().is_ok())
    }

    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(Some(Box::new(ConstantOfShapeState)))
    }
}

/// Session state for `ConstantOfShape`, used when `shape` must be resolved against
/// the session's symbols before the output can be built.
#[derive(Clone, Debug)]
struct ConstantOfShapeState;

impl OpState for ConstantOfShapeState {
    /// Resolves `shape` against `session.resolved_symbols`, then materializes the
    /// constant tensor.
    ///
    /// Errors if `op` is not a `ConstantOfShape`, or if a dimension is still not
    /// concrete after symbol resolution.
    fn eval(
        &mut self,
        session: &mut SessionState,
        op: &dyn Op,
        _inputs: TVec<Arc<Tensor>>,
    ) -> TractResult<TVec<Arc<Tensor>>> {
        // This state is only created by `ConstantOfShape::state`, so a mismatch is a
        // wiring bug; surface it as an error instead of panicking the session.
        let op = match op.downcast_ref::<ConstantOfShape>() {
            Some(op) => op,
            None => bail!("ConstantOfShapeState evaluated with unexpected op {}", op.name()),
        };
        let shape = op
            .shape
            .iter()
            .map(|d| d.eval(&session.resolved_symbols).to_usize())
            .collect::<TractResult<TVec<_>>>()?;
        Ok(tvec!(op.materialize(&shape)?))
    }
}