1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
use crate::internal::*;

/// Per-session state of a [`TypedSource`] node.
///
/// The wrapped `usize` is the id of the source node. It is the key used to look
/// up the tensor fed to that node in `SessionState::inputs`.
#[derive(Debug, Clone, new)]
pub struct SourceState(pub usize);

impl OpState for SourceState {
    /// Returns the tensor registered in the session inputs for this source node.
    ///
    /// # Errors
    ///
    /// Fails with a descriptive error, instead of panicking on the missing map key,
    /// when the session holds no input for node `self.0`.
    fn eval(
        &mut self,
        session: &mut SessionState,
        _op: &dyn Op,
        _inputs: TVec<Arc<Tensor>>,
    ) -> TractResult<TVec<Arc<Tensor>>> {
        let input = session
            .inputs
            .get(&self.0)
            .with_context(|| format!("Input for source node {} is missing", self.0))?;
        Ok(tvec!(input.clone()))
    }
}

/// Input node of a typed model. It takes no inputs and produces a single output
/// described by `fact`.
///
/// At run time the actual value is read from the session inputs through the
/// per-session [`SourceState`] built by `EvalOp::state`.
#[derive(Debug, Clone, new, Hash)]
pub struct TypedSource {
    /// Datum type and shape of the tensor this source produces.
    pub fact: TypedFact,
}

impl_dyn_hash!(TypedSource);

impl Op for TypedSource {
    fn name(&self) -> Cow<str> {
        "Source".into()
    }

    op_core_lir_mir!();
    op_as_typed_op!();
}

impl EvalOp for TypedSource {
    /// Always `false`. The produced value lives in the session, so evaluation goes
    /// through the [`SourceState`] returned by `state`.
    fn is_stateless(&self) -> bool {
        false
    }

    /// Builds the per-session state, remembering `node_id` so `eval` can find this
    /// node's entry in the session inputs.
    fn state(
        &self,
        _session: &mut SessionState,
        node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(Some(Box::new(SourceState(node_id))))
    }
}

impl TypedOp for TypedSource {
    /// The only output fact is the declared `fact`. There are no inputs to inspect.
    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        Ok(tvec!(self.fact.clone()))
    }

    /// Applies `change` to the declared shape, propagating any error from
    /// `change_shape`. It then proposes replacing this op with a source that
    /// carries the updated fact.
    ///
    /// A source node has no inputs (it is wired with `&[]`), so `_io` is not consulted.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let mut fact = self.fact.clone();
        change.change_shape(&mut fact.shape)?;
        Ok(Some(AxisChangeConsequence::new(
            model,
            node,
            Some(Box::new(TypedSource::new(fact))),
            change,
        )))
    }

    /// Wires into `target` a new source node, with the same name and datum type,
    /// whose shape has every dimension evaluated against `values`.
    ///
    /// NOTE(review): only `datum_type` and the evaluated shape are carried into
    /// the new fact. Confirm that no other fact attribute needs preserving.
    fn concretize_dims(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        _mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        let shape: TVec<_> = self.fact.shape.iter().map(|d| d.eval(values)).collect();
        target.wire_node(
            &node.name,
            Self { fact: TypedFact::dt_shape(self.fact.datum_type, &*shape) },
            &[],
        )
    }

    as_op!();
}