// File: tract_pulse/fact.rs

1use crate::internal::*;
2
3#[derive(Clone, Debug, PartialEq, Eq, Hash)]
4pub struct StreamInfo {
5    pub axis: usize,
6    pub dim: TDim,
7    pub delay: usize,
8}
9
10pub trait StreamFact {
11    fn stream_info(&self, stream_sym: &Symbol) -> Option<(usize, &TDim)>;
12}
13
14impl StreamFact for ShapeFact {
15    fn stream_info(&self, stream_sym: &Symbol) -> Option<(usize, &TDim)> {
16        let streaming_dims: TVec<(usize, &TDim)> = (**self)
17            .iter()
18            .enumerate()
19            .filter(|(_ix, d)| d.symbols().contains(stream_sym))
20            .collect();
21        if streaming_dims.len() != 1 {
22            None
23        } else {
24            Some(streaming_dims[0])
25        }
26    }
27}
28
29#[derive(Clone, PartialEq, Eq, Hash)]
30pub struct PulsedFact {
31    pub datum_type: DatumType,
32    pub shape: ShapeFact,
33    pub stream: Option<StreamInfo>,
34}
35
36impl PulsedFact {
37    pub fn from_tensor_fact_pulse(
38        tf: &TypedFact,
39        symbol: &Symbol,
40        pulse: &TDim,
41    ) -> TractResult<PulsedFact> {
42        let datum_type = tf.datum_type;
43        let (axis, len) = tf
44            .shape
45            .stream_info(symbol)
46            .ok_or_else(|| format_err!("Can not pulse a tensor with no streaming dim"))?;
47        let mut shape: TVec<TDim> = tf.shape.to_tvec();
48        shape[axis] = pulse.clone();
49        Ok(PulsedFact {
50            datum_type,
51            shape: shape.into(),
52            stream: Some(StreamInfo { axis, dim: len.clone(), delay: 0 }),
53        })
54    }
55
56    pub fn pulse(&self) -> Option<&TDim> {
57        if let Some(stream) = &self.stream {
58            Some(&self.shape[stream.axis])
59        } else {
60            None
61        }
62    }
63
64    pub fn to_pulse_fact(&self) -> TypedFact {
65        self.datum_type.fact(self.shape.clone())
66    }
67
68    pub fn streaming_shape(&self) -> TVec<TDim> {
69        if let Some(stream) = &self.stream {
70            self.shape
71                .iter()
72                .enumerate()
73                .map(|(ix, d)| if ix == stream.axis { stream.dim.clone() } else { d.clone() })
74                .collect()
75        } else {
76            self.shape.to_tvec()
77        }
78    }
79
80    pub fn to_streaming_fact(&self) -> TypedFact {
81        let mut info = self.to_pulse_fact();
82        info.shape = self.streaming_shape().into();
83        info
84    }
85}
86
87impl fmt::Debug for PulsedFact {
88    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
89        use tract_itertools::Itertools;
90        if let Some(stream) = &self.stream {
91            write!(
92                fmt,
93                "{},{:?} [pulse axis:{} ∂:{} full dim:{}]",
94                self.shape.iter().join(","),
95                self.datum_type,
96                stream.axis,
97                stream.delay,
98                stream.dim
99            )
100        } else {
101            write!(fmt, "{:?}", self.to_pulse_fact())
102        }
103    }
104}
105
106impl Fact for PulsedFact {
107    fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
108        Ok(Cow::Owned(self.into()))
109    }
110
111    fn same_as(&self, other: &dyn Fact) -> bool {
112        if let Some(other) = other.downcast_ref::<PulsedFact>() {
113            other == self
114        } else {
115            false
116        }
117    }
118
119    fn compatible_with(&self, other: &dyn Fact) -> bool {
120        self.same_as(other)
121    }
122
123    fn datum_type(&self) -> Option<DatumType> {
124        Some(self.datum_type)
125    }
126}
127
128impl From<PulsedFact> for TypedFact {
129    fn from(fact: PulsedFact) -> TypedFact {
130        fact.datum_type.fact(fact.shape)
131    }
132}
133
134impl<'a> From<&'a PulsedFact> for TypedFact {
135    fn from(fact: &'a PulsedFact) -> TypedFact {
136        fact.datum_type.fact(fact.shape.clone())
137    }
138}