1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
use crate::internal::*;

/// Runtime state shared by the source ops in this file.
///
/// The wrapped `usize` is the id of the node the state was created for (see
/// the `StatefullOp::state` implementations below); it is used as the key
/// into `SessionState::inputs`.
#[derive(Debug, Clone, new)]
pub struct SourceState(pub usize);

impl OpState for SourceState {
    /// Returns the session input registered under this state's node id.
    ///
    /// The op and the incoming tensors are ignored. The lookup uses the
    /// indexing operator, so it presumably panics when no input was
    /// registered for the node -- confirm against the type of
    /// `SessionState::inputs`.
    fn eval(
        &mut self,
        session: &mut SessionState,
        _op: &dyn Op,
        _inputs: TVec<Arc<Tensor>>,
    ) -> TractResult<TVec<Arc<Tensor>>> {
        let tensor = session.inputs[&self.0].clone();
        Ok(tvec!(tensor))
    }
}

/// Source op of a typed model: it has no inputs and exposes a single output
/// described by `fact`.
#[derive(Debug, Clone, new, Hash)]
pub struct TypedSource {
    /// Fact reported for the single output of this source.
    pub fact: TypedFact,
}

tract_linalg::impl_dyn_hash!(TypedSource);

impl Op for TypedSource {
    /// Name of the op: always `"Source"`.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Source")
    }

    canonic!();
    op_core_lir_mir!();
    op_as_typed_op!();
    not_a_pulsed_op!();
}

impl StatefullOp for TypedSource {
    /// Creates a `SourceState` bound to `node_id`; the session is not used.
    fn state(
        &self,
        _session: &mut SessionState,
        node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        let state = SourceState(node_id);
        Ok(Some(Box::new(state)))
    }
}

impl TypedOp for TypedSource {
    /// Reports a copy of the stored fact as the only output fact.
    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let fact = self.fact.clone();
        Ok(tvec!(fact))
    }

    /// Applies `change` to a copy of the stored fact's shape and proposes a
    /// new `TypedSource` carrying the updated fact as replacement op.
    ///
    /// Fails if `change` can not be applied to the shape.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let mut reshaped = self.fact.clone();
        change.change_shape(&mut reshaped.shape)?;
        let consequence = AxisChangeConsequence::new(
            model,
            node,
            Some(Box::new(Self::new(reshaped))),
            change,
        );
        Ok(Some(consequence))
    }

    /// Translates this source into the pulsed model `target`.
    ///
    /// A pulsed fact is derived from the node's first output fact and
    /// `pulse`, then registered in `target` as a source named after the node.
    fn pulsify(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut PulsedModel,
        _mapping: &HashMap<OutletId, OutletId>,
        pulse: usize,
    ) -> TractResult<TVec<OutletId>> {
        let typed_fact = &node.outputs[0].fact;
        let fact = crate::pulse::PulsedFact::from_tensor_fact_pulse(typed_fact, pulse)?;
        let outlet = target.add_source(node.name.clone(), fact)?;
        Ok(tvec!(outlet))
    }

    /// Wires into `target` a copy of this source in which the streaming
    /// dimension, if any, has been concretized with `stream_dim`.
    ///
    /// When the stored fact's shape carries no stream info, the fact is
    /// copied unchanged.
    fn concretize_stream_dim(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        _mapping: &HashMap<OutletId, OutletId>,
        stream_dim: usize,
    ) -> TractResult<TVec<OutletId>> {
        let mut fact = self.fact.clone();
        if let Some(info) = self.fact.shape.stream_info.as_ref() {
            let axis = info.axis;
            let concrete = fact.shape.dim(axis).concretize_stream_dim(stream_dim);
            fact.shape.set_dim(axis, concrete)?;
        }
        // A source has no inputs, hence the empty outlet slice.
        target.wire_node(&node.name, TypedSource { fact }, &[])
    }

    as_op!();
}

/// Source op of a pulsed model: it has no inputs and exposes a single output
/// described by `fact`.
#[derive(Debug, Clone, new, Hash)]
pub struct PulsedSource {
    /// Pulsed fact reported for the single output of this source.
    fact: PulsedFact,
}

tract_linalg::impl_dyn_hash!(PulsedSource);

impl Op for PulsedSource {
    /// Name of the op: always `"PulsedSource"`.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("PulsedSource")
    }

    canonic!();
    op_core_lir_mir!();
    not_a_typed_op!();
    op_as_pulsed_op!();
}

impl StatefullOp for PulsedSource {
    /// Creates a `SourceState` bound to `node_id`; the session is not used.
    fn state(
        &self,
        _session: &mut SessionState,
        node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        let state = SourceState(node_id);
        Ok(Some(Box::new(state)))
    }
}

impl PulsedOp for PulsedSource {
    /// Reports a copy of the stored pulsed fact as the only output fact.
    fn pulsed_output_facts(&self, _inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let fact = self.fact.clone();
        Ok(tvec!(fact))
    }

    /// Builds the equivalent `TypedSource` from the pulsed fact's datum type
    /// and shape.
    ///
    /// The trait signature leaves no room for an error, so this panics if
    /// `TypedFact::dt_shape` fails.
    fn to_typed(&self) -> Box<dyn TypedOp> {
        let typed_fact = TypedFact::dt_shape(self.fact.datum_type, &*self.fact.shape).unwrap();
        Box::new(TypedSource::new(typed_fact))
    }

    as_op!();
}