1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
use crate::internal::*; use std::fmt; use std::convert::TryFrom; pub mod delay; #[derive(Clone, PartialEq)] pub struct PulsedTensorFact { pub dt: DatumType, pub shape: TVec<usize>, pub axis: usize, pub dim: TDim, pub delay: usize, } impl fmt::Debug for PulsedTensorFact { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { use itertools::Itertools; write!( fmt, "{}x{:?} [pulse axis:{} ∂:{} full dim:{:?}]", self.shape.iter().join("x"), self.dt, self.axis, self.delay, self.dim ) } } impl TensorInfo for PulsedTensorFact { fn to_tensor_fact(&self) -> TensorFact { TensorFact::dt_shape(self.dt, &self.shape) } } impl TryFrom<PulsedTensorFact> for TypedTensorInfo { type Error = TractError; fn try_from(fact: PulsedTensorFact) -> TractResult<TypedTensorInfo> { TypedTensorInfo::dt_shape(fact.dt, &*fact.shape) } } impl PulsedTensorFact { pub fn from_tensor_fact_pulse( tf: &NormalizedTensorInfo, pulse: usize, ) -> TractResult<PulsedTensorFact> { let dt = tf.datum_type; let stream = tf.shape.stream_info.as_ref().ok_or("Can not pulse a tensor with no streaming dim")?; let shape = tf.shape.iter().map(|d| d.to_integer().map(|d| d as usize).unwrap_or(pulse)).collect(); Ok(PulsedTensorFact { dt, shape, axis: stream.axis, dim: stream.len.clone(), delay: 0 }) } pub fn pulse(&self) -> usize { self.shape[self.axis] } pub fn to_pulse_fact(&self) -> NormalizedTensorInfo { NormalizedTensorInfo::dt_shape(self.dt, &*self.shape).unwrap() } pub fn streaming_shape(&self) -> Vec<TDim> { self.shape .iter() .enumerate() .map(|(ix, &d)| if ix == self.axis { self.dim.clone() } else { d.to_dim() }) .collect() } pub fn to_streaming_fact(&self) -> NormalizedTensorInfo { let mut info = self.to_pulse_fact(); info.shape.stream_info = Some(StreamInfo { axis: self.axis, len: self.dim.clone() }); info } } pub type PulsedModel = ModelImpl<PulsedTensorFact, Box<dyn TypedOp>>; impl PulsedModel { pub fn new(source: &NormalizedModel, pulse: usize) -> TractResult<PulsedModel> { Ok(PulsedModel::new_with_mapping(source, pulse)?.0) } pub fn new_with_mapping( source: &NormalizedModel, pulse: usize, ) -> TractResult<(PulsedModel, HashMap<OutletId, OutletId>)> { crate::model::compact::translate(source, &pulse) } pub fn into_typed(self) -> TractResult<TypedModel> { crate::model::compact::compact(&self) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_source_must_stream() { let mut model = InferenceModel::default(); let _a = model.add_source("a", TensorFact::dt_shape(DatumType::F32, vec![1, 2, 3])).unwrap(); model.auto_outputs().unwrap(); assert!( PulsedModel::new(&model.into_typed().unwrap().into_normalized().unwrap(), 4).is_err() ); let mut model = InferenceModel::default(); let _a = model .add_source( "a", TensorFact::dt_shape(DatumType::F32, vec![1.to_dim(), TDim::s(), 3.to_dim()]), ) .unwrap(); model.auto_outputs().unwrap(); let pulse = PulsedModel::new(&model.into_typed().unwrap().into_normalized().unwrap(), 4).unwrap(); assert_eq!( pulse.outlet_fact(OutletId::new(0, 0)).unwrap().to_tensor_fact(), TensorFact::dt_shape(DatumType::F32, vec!(1, 4, 3)) ); } #[test] fn test_immediate() { let mut model = InferenceModel::default(); let _a = model .add_source( "a", TensorFact::dt_shape(DatumType::F32, vec![TDim::s(), 2.to_dim(), 3.to_dim()]), ) .unwrap(); model.auto_outputs().unwrap(); let pulse = PulsedModel::new(&model.into_normalized().unwrap(), 4).unwrap(); assert_eq!( pulse.input_fact(0).unwrap().to_tensor_fact(), TensorFact::dt_shape(DatumType::F32, vec!(4, 2, 3)) ); assert_eq!( pulse.output_fact(0).unwrap().to_tensor_fact(), TensorFact::dt_shape(DatumType::F32, vec!(4, 2, 3)) ); } }