1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
use crate::internal::*; #[derive(Debug, Clone, new)] pub struct SourceState(pub usize); impl OpState for SourceState { fn eval( &mut self, session: &mut SessionState, _op: &dyn Op, _inputs: TVec<Arc<Tensor>>, ) -> TractResult<TVec<Arc<Tensor>>> { Ok(tvec!(session.inputs[&self.0].clone())) } } #[derive(Debug, Clone, new, Hash)] pub struct TypedSource { fact: TypedFact, } tract_linalg::impl_dyn_hash!(TypedSource); impl Op for TypedSource { fn name(&self) -> Cow<str> { "Source".into() } canonic!(); op_core_lir_mir!(); op_as_typed_op!(); not_a_pulsed_op!(); } impl StatefullOp for TypedSource { fn state( &self, _session: &mut SessionState, node_id: usize, ) -> TractResult<Option<Box<dyn OpState>>> { Ok(Some(Box::new(SourceState(node_id)))) } } impl TypedOp for TypedSource { fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> { Ok(tvec!(self.fact.clone())) } fn change_axes( &self, model: &TypedModel, node: &TypedNode, _io: InOut, change: &AxisOp, ) -> TractResult<Option<AxisChangeConsequence>> { let mut fact = self.fact.clone(); change.change_shape(&mut fact.shape)?; Ok(Some(AxisChangeConsequence::new( model, node, Some(Box::new(TypedSource::new(fact))), change, ))) } fn pulsify( &self, _source: &TypedModel, node: &TypedNode, target: &mut PulsedModel, _mapping: &HashMap<OutletId, OutletId>, pulse: usize, ) -> TractResult<TVec<OutletId>> { let pulsed_fact = crate::pulse::PulsedFact::from_tensor_fact_pulse(&node.outputs[0].fact, pulse)?; let id = target.add_source(node.name.clone(), pulsed_fact)?; Ok(tvec!(id)) } as_op!(); } #[derive(Debug, Clone, new, Hash)] pub struct PulsedSource { fact: PulsedFact, } tract_linalg::impl_dyn_hash!(PulsedSource); impl Op for PulsedSource { fn name(&self) -> Cow<str> { "PulsedSource".into() } canonic!(); op_core_lir_mir!(); not_a_typed_op!(); op_as_pulsed_op!(); } impl StatefullOp for PulsedSource { fn state( &self, _session: &mut SessionState, node_id: usize, ) -> TractResult<Option<Box<dyn OpState>>> { Ok(Some(Box::new(SourceState(node_id)))) } } impl PulsedOp for PulsedSource { fn pulsed_output_facts(&self, _inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> { Ok(tvec!(self.fact.clone())) } fn to_typed(&self) -> Box<dyn TypedOp> { Box::new(TypedSource::new( TypedFact::dt_shape(self.fact.datum_type, &*self.fact.shape).unwrap(), )) } as_op!(); }