1use crate::internal::*;
2
3#[derive(Clone, Debug, PartialEq, Eq, Hash)]
4pub struct StreamInfo {
5 pub axis: usize,
6 pub dim: TDim,
7 pub delay: usize,
8}
9
10pub trait StreamFact {
11 fn stream_info(&self, stream_sym: &Symbol) -> Option<(usize, &TDim)>;
12}
13
14impl StreamFact for ShapeFact {
15 fn stream_info(&self, stream_sym: &Symbol) -> Option<(usize, &TDim)> {
16 let streaming_dims: TVec<(usize, &TDim)> = (**self)
17 .iter()
18 .enumerate()
19 .filter(|(_ix, d)| d.symbols().contains(stream_sym))
20 .collect();
21 if streaming_dims.len() != 1 {
22 None
23 } else {
24 Some(streaming_dims[0])
25 }
26 }
27}
28
29#[derive(Clone, PartialEq, Eq, Hash)]
30pub struct PulsedFact {
31 pub datum_type: DatumType,
32 pub shape: ShapeFact,
33 pub stream: Option<StreamInfo>,
34}
35
36impl PulsedFact {
37 pub fn from_tensor_fact_pulse(
38 tf: &TypedFact,
39 symbol: &Symbol,
40 pulse: &TDim,
41 ) -> TractResult<PulsedFact> {
42 let datum_type = tf.datum_type;
43 let (axis, len) = tf
44 .shape
45 .stream_info(symbol)
46 .ok_or_else(|| format_err!("Can not pulse a tensor with no streaming dim"))?;
47 let mut shape: TVec<TDim> = tf.shape.to_tvec();
48 shape[axis] = pulse.clone();
49 Ok(PulsedFact {
50 datum_type,
51 shape: shape.into(),
52 stream: Some(StreamInfo { axis, dim: len.clone(), delay: 0 }),
53 })
54 }
55
56 pub fn pulse(&self) -> Option<&TDim> {
57 if let Some(stream) = &self.stream {
58 Some(&self.shape[stream.axis])
59 } else {
60 None
61 }
62 }
63
64 pub fn to_pulse_fact(&self) -> TypedFact {
65 self.datum_type.fact(self.shape.clone())
66 }
67
68 pub fn streaming_shape(&self) -> TVec<TDim> {
69 if let Some(stream) = &self.stream {
70 self.shape
71 .iter()
72 .enumerate()
73 .map(|(ix, d)| if ix == stream.axis { stream.dim.clone() } else { d.clone() })
74 .collect()
75 } else {
76 self.shape.to_tvec()
77 }
78 }
79
80 pub fn to_streaming_fact(&self) -> TypedFact {
81 let mut info = self.to_pulse_fact();
82 info.shape = self.streaming_shape().into();
83 info
84 }
85}
86
87impl fmt::Debug for PulsedFact {
88 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
89 use tract_itertools::Itertools;
90 if let Some(stream) = &self.stream {
91 write!(
92 fmt,
93 "{},{:?} [pulse axis:{} ∂:{} full dim:{}]",
94 self.shape.iter().join(","),
95 self.datum_type,
96 stream.axis,
97 stream.delay,
98 stream.dim
99 )
100 } else {
101 write!(fmt, "{:?}", self.to_pulse_fact())
102 }
103 }
104}
105
106impl Fact for PulsedFact {
107 fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
108 Ok(Cow::Owned(self.into()))
109 }
110
111 fn same_as(&self, other: &dyn Fact) -> bool {
112 if let Some(other) = other.downcast_ref::<PulsedFact>() {
113 other == self
114 } else {
115 false
116 }
117 }
118
119 fn compatible_with(&self, other: &dyn Fact) -> bool {
120 self.same_as(other)
121 }
122
123 fn datum_type(&self) -> Option<DatumType> {
124 Some(self.datum_type)
125 }
126}
127
128impl From<PulsedFact> for TypedFact {
129 fn from(fact: PulsedFact) -> TypedFact {
130 fact.datum_type.fact(fact.shape)
131 }
132}
133
134impl<'a> From<&'a PulsedFact> for TypedFact {
135 fn from(fact: &'a PulsedFact) -> TypedFact {
136 fact.datum_type.fact(fact.shape.clone())
137 }
138}