1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
use crate::internal::*;

/// Streaming metadata attached to a [`PulsedFact`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct StreamInfo {
    /// Index of the streaming axis within the owning fact's shape.
    pub axis: usize,
    /// Full (non-pulsed) dimension of the streaming axis.
    pub dim: TDim,
    /// Delay along the stream; `PulsedFact::from_tensor_fact_pulse` initializes it to 0.
    /// NOTE(review): presumably counted in elements along `axis` — confirm.
    pub delay: usize,
}

/// Lookup of the streaming axis in a shape-like fact.
pub trait StreamFact {
    /// Returns the index of the axis whose dimension depends on `stream_sym`, together with
    /// that dimension, or `None` when no such axis can be singled out.
    fn stream_info(&self, stream_sym: &Symbol) -> Option<(usize, &TDim)>;
}

impl StreamFact for ShapeFact {
    /// Finds the unique axis whose dimension expression mentions `stream_sym`.
    ///
    /// Returns `None` when no axis mentions the symbol, and also when more than one does.
    fn stream_info(&self, stream_sym: &Symbol) -> Option<(usize, &TDim)> {
        let mut matching = (**self)
            .iter()
            .enumerate()
            .filter(|(_, dim)| dim.symbols().contains(stream_sym));
        let first = matching.next()?;
        // A second match means the streaming axis is ambiguous: report nothing.
        if matching.next().is_none() {
            Some(first)
        } else {
            None
        }
    }
}

/// A fact describing a single pulse of a (possibly) streamed tensor.
#[derive(Clone, Eq, Hash, PartialEq)]
pub struct PulsedFact {
    /// Element type of the tensor.
    pub datum_type: DatumType,
    /// Shape of one pulse: on the streaming axis it holds the pulse size.
    pub shape: ShapeFact,
    /// Streaming axis information, or `None` when the fact has no streaming axis.
    pub stream: Option<StreamInfo>,
}



impl PulsedFact {
    /// Builds a pulsed fact from a typed fact by replacing its streaming axis with `pulse`.
    ///
    /// The streaming axis is located with `StreamFact::stream_info` on `tf.shape`. The
    /// original dimension of that axis is kept as the full streaming dimension in
    /// `StreamInfo::dim`, and the delay starts at 0.
    ///
    /// # Errors
    ///
    /// Fails when `stream_info` finds no streaming axis for `symbol`. For a `ShapeFact` this
    /// happens both when no axis mentions the symbol and when several axes do.
    pub fn from_tensor_fact_pulse(
        tf: &TypedFact,
        symbol: &Symbol,
        pulse: &TDim,
    ) -> TractResult<PulsedFact> {
        let datum_type = tf.datum_type;
        let (axis, full_dim) = tf
            .shape
            .stream_info(symbol)
            .ok_or_else(|| format_err!("Can not pulse a tensor with no streaming dim"))?;
        let mut shape: TVec<TDim> = tf.shape.iter().collect();
        shape[axis] = pulse.clone();
        Ok(PulsedFact {
            datum_type,
            shape: shape.into(),
            stream: Some(StreamInfo { axis, dim: full_dim.clone(), delay: 0 }),
        })
    }

    /// Returns the pulse size, i.e. the dimension of the streaming axis in `self.shape`,
    /// or `None` when the fact has no streaming axis.
    pub fn pulse(&self) -> Option<&TDim> {
        // `Option::map` rather than a hand-written `if let .. else None` (clippy::manual_map).
        self.stream.as_ref().map(|stream| &self.shape[stream.axis])
    }

    /// Converts to a `TypedFact` with the pulsed shape (the streaming axis holds the pulse
    /// size); streaming information is dropped.
    pub fn to_pulse_fact(&self) -> TypedFact {
        self.datum_type.fact(self.shape.clone())
    }

    /// Returns the shape with the streaming axis restored to its full dimension
    /// (`stream.dim`). Without streaming information, the pulse shape is returned as-is.
    pub fn streaming_shape(&self) -> TVec<TDim> {
        if let Some(stream) = &self.stream {
            self.shape
                .iter()
                .enumerate()
                .map(|(ix, d)| if ix == stream.axis { stream.dim.clone() } else { d })
                .collect()
        } else {
            self.shape.to_tvec()
        }
    }

    /// Converts to a `TypedFact` whose shape is [`Self::streaming_shape`], i.e. the full
    /// streamed shape rather than the pulse shape.
    pub fn to_streaming_fact(&self) -> TypedFact {
        let mut info = self.to_pulse_fact();
        info.shape = self.streaming_shape().into();
        info
    }
}

impl fmt::Debug for PulsedFact {
    /// Streamed facts render as `dims,datum_type [pulse axis:A ∂:D full dim:F]`; facts
    /// without streaming information fall back to the `Debug` of the pulse-shaped `TypedFact`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use tract_itertools::Itertools;
        match &self.stream {
            None => write!(f, "{:?}", self.to_pulse_fact()),
            Some(stream) => {
                let dims = self.shape.iter().join(",");
                write!(
                    f,
                    "{},{:?} [pulse axis:{} ∂:{} full dim:{}]",
                    dims, self.datum_type, stream.axis, stream.delay, stream.dim
                )
            }
        }
    }
}

impl Fact for PulsedFact {
    fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
        Ok(Cow::Owned(self.into()))
    }

    fn same_as(&self, other: &dyn Fact) -> bool {
        if let Some(other) = other.downcast_ref::<PulsedFact>() {
            other == self
        } else {
            false
        }
    }

    fn compatible_with(&self, other: &dyn Fact) -> bool {
        self.same_as(other)
    }

    fn datum_type(&self) -> Option<DatumType> {
        Some(self.datum_type)
    }
}

impl From<PulsedFact> for TypedFact {
    fn from(fact: PulsedFact) -> TypedFact {
        fact.datum_type.fact(fact.shape)
    }
}

impl<'a> From<&'a PulsedFact> for TypedFact {
    fn from(fact: &'a PulsedFact) -> TypedFact {
        fact.datum_type.fact(fact.shape.clone())
    }
}