1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use crate::internal::*;
// Shared symbol named 'S', built lazily the first time it is dereferenced.
// `stream_symbol()` and `stream_dim()` below both read it.
lazy_static::lazy_static! {
static ref S: Symbol = Symbol::new('S');
}
/// Returns a copy of the shared `S` symbol that marks streaming dimensions.
pub fn stream_symbol() -> Symbol {
    let symbol: Symbol = *S;
    symbol
}
pub fn stream_dim() -> TDim {
(*S).into()
}
/// Access to streaming-axis information.
pub trait StreamFact {
/// Returns the streaming axis as `(axis index, dimension)`, or `None` when
/// there is no such axis to report.
///
/// The `ShapeFact` implementation in this file returns `None` unless exactly
/// one dimension mentions the stream symbol.
fn stream_info(&self) -> Option<(usize, &TDim)>;
}
impl StreamFact for ShapeFact {
    /// Locates the single dimension whose expression mentions the stream symbol.
    ///
    /// Returns `Some((index, dim))` when exactly one dimension qualifies, and
    /// `None` when there is no such dimension or more than one.
    fn stream_info(&self) -> Option<(usize, &TDim)> {
        let dims = &**self;
        let mut candidates = dims
            .iter()
            .enumerate()
            .filter(|(_, dim)| dim.symbols().contains(&stream_symbol()));
        // Looking at the first two matches is enough to tell "exactly one"
        // apart from "none" and "several".
        match (candidates.next(), candidates.next()) {
            (Some(only), None) => Some(only),
            _ => None,
        }
    }
}
/// Description of a tensor as seen one pulse at a time.
///
/// `shape` is the per-pulse shape, while `dim` remembers the full dimension
/// of the streaming axis so the streaming shape can be rebuilt.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PulsedFact {
// Element type of the tensor.
pub datum_type: DatumType,
// Shape of one pulse: `from_tensor_fact_pulse` stores the pulse size at `axis`.
pub shape: ShapeFact,
// Index of the streaming axis within `shape`.
pub axis: usize,
// Dimension of the streaming axis in the original, non-pulsed fact.
pub dim: TDim,
// Initialised to 0 by `from_tensor_fact_pulse` and printed as `∂` by `Debug`.
// NOTE(review): presumably the lag accumulated along the streaming axis — confirm.
pub delay: usize,
}
// NOTE(review): macro defined elsewhere; presumably exposes the derived `Hash`
// through tract's dynamic hashing — confirm.
impl_dyn_hash!(PulsedFact);
impl PulsedFact {
    /// Builds the pulsed counterpart of `tf` for the given pulse size.
    ///
    /// The streaming axis is located with `stream_info()`. In the resulting
    /// shape that axis holds `pulse`, the original streaming dimension is
    /// kept in `dim`, and `delay` starts at 0.
    ///
    /// # Errors
    /// Fails when `tf.shape.stream_info()` returns `None`.
    pub fn from_tensor_fact_pulse(tf: &TypedFact, pulse: usize) -> TractResult<PulsedFact> {
        let (axis, stream_len) = tf
            .shape
            .stream_info()
            .ok_or_else(|| format_err!("Can not pulse a tensor with no streaming dim"))?;
        // Copy every dimension as is, except the streaming one, which becomes the pulse.
        let pulsed_shape: TVec<TDim> = tf
            .shape
            .iter()
            .enumerate()
            .map(|(ix, d)| if ix == axis { pulse.into() } else { d })
            .collect();
        Ok(PulsedFact {
            datum_type: tf.datum_type,
            shape: pulsed_shape.into(),
            axis,
            dim: stream_len.clone(),
            delay: 0,
        })
    }

    /// Returns the pulse size, i.e. the dimension stored at `axis` in `shape`.
    ///
    /// # Panics
    /// Panics if that dimension cannot be converted to a `usize`.
    pub fn pulse(&self) -> usize {
        let pulse_dim = &self.shape[self.axis];
        pulse_dim.to_usize().expect("Pulse should be an integer. This is a tract bug.")
    }

    /// Returns a `TypedFact` with this fact's datum type and per-pulse shape.
    pub fn to_pulse_fact(&self) -> TypedFact {
        let pulse_shape = self.shape.clone();
        self.datum_type.fact(pulse_shape)
    }

    /// Returns the dimensions of `shape` with the one at `axis` replaced by `dim`.
    pub fn streaming_shape(&self) -> Vec<TDim> {
        self.shape
            .iter()
            .enumerate()
            .map(|(position, pulsed_dim)| {
                if position != self.axis {
                    pulsed_dim
                } else {
                    self.dim.clone()
                }
            })
            .collect()
    }

    /// Returns a `TypedFact` like `to_pulse_fact()`, but with the dimension at
    /// `axis` set back to the full streaming dimension `dim`.
    pub fn to_streaming_fact(&self) -> TypedFact {
        let mut streaming = self.datum_type.fact(self.shape.clone());
        streaming.shape.set(self.axis, self.dim.clone());
        streaming
    }
}
impl fmt::Debug for PulsedFact {
    /// Writes the comma-separated pulse shape and the datum type, followed by
    /// the pulse axis, the delay and the full streaming dimension in brackets.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        use tract_itertools::Itertools;
        let dims = self.shape.iter().join(",");
        write!(
            fmt,
            "{},{:?} [pulse axis:{} ∂:{} full dim:{}]",
            dims, self.datum_type, self.axis, self.delay, self.dim
        )
    }
}
impl Fact for PulsedFact {
fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
Ok(Cow::Owned(self.into()))
}
fn same_as(&self, other: &dyn Fact) -> bool {
if let Some(other) = other.downcast_ref::<PulsedFact>() {
other == self
} else {
false
}
}
fn compatible_with(&self, other: &dyn Fact) -> bool {
self.same_as(other)
}
fn datum_type(&self) -> Option<DatumType> {
Some(self.datum_type)
}
}
impl From<PulsedFact> for TypedFact {
fn from(fact: PulsedFact) -> TypedFact {
fact.datum_type.fact(fact.shape)
}
}
impl<'a> From<&'a PulsedFact> for TypedFact {
    /// Builds a `TypedFact` from a borrowed pulsed fact by cloning its
    /// per-pulse shape; `axis`, `dim` and `delay` are not carried over.
    fn from(fact: &'a PulsedFact) -> TypedFact {
        let pulse_shape = fact.shape.clone();
        fact.datum_type.fact(pulse_shape)
    }
}