use crate::ops::source::Source;
use std::collections::HashMap;
use crate::model::dsl::*;
use crate::ops::prelude::*;
pub mod delay;
#[derive(Clone, Debug, PartialEq)]
pub struct PulsedTensorFact {
pub dt: DatumType,
pub shape: TVec<usize>,
pub axis: usize,
pub dim: TDim,
pub delay: usize,
}
impl PulsedTensorFact {
pub fn from_tensor_fact_pulse(tf: &TensorFact, pulse: usize) -> TractResult<PulsedTensorFact> {
let dt = tf
.datum_type
.concretize()
.ok_or("Can not use pulse a tensor with no type")?;
let axis = tf
.stream_info()?
.ok_or("Can not pulse a tensor with no streaming dim")?
.axis;
let shape = tf
.shape
.concretize()
.ok_or("Can not pulse a tensor with unknown shape")?;
let dim = shape[axis];
let shape = shape
.iter()
.enumerate()
.map(|(ix, &d)| {
if ix == axis {
Ok(pulse)
} else {
d.to_integer().map(|d| d as usize)
}
})
.collect::<TractResult<_>>()?;
Ok(PulsedTensorFact {
dt,
shape,
axis,
dim,
delay: 0,
})
}
pub fn pulse(&self) -> usize {
self.shape[self.axis]
}
pub fn to_pulse_fact(&self) -> TensorFact {
TensorFact::dt_shape(self.dt, self.shape.clone())
}
pub fn streaming_shape(&self) -> Vec<TDim> {
self.shape
.iter()
.enumerate()
.map(|(ix, &d)| {
if ix == self.axis {
self.dim
} else {
d.to_dim()
}
})
.collect()
}
pub fn to_streaming_fact(&self) -> TensorFact {
TensorFact::dt_shape(self.dt, self.streaming_shape())
}
}
#[derive(Clone, Debug, new)]
pub struct PulsifiedOp {
pub op: Box<Op>,
pub outputs: TVec<PulsedTensorFact>,
}
pub fn pulsify(
old: &Model,
pulse: usize,
) -> TractResult<(Model, PulsedTensorFact, PulsedTensorFact)> {
let mut p_model = PulsifiedModel::new(old, pulse)?;
let in_id = p_model.model.inputs()?[0];
let out_id = p_model.model.outputs()?[0];
let in_fact: PulsedTensorFact = p_model.facts.remove(&in_id).unwrap();
let out_fact: PulsedTensorFact = p_model.facts.remove(&out_id).unwrap();
Ok((p_model.model, in_fact, out_fact))
}
#[derive(Clone, Debug)]
struct PulsifiedModel {
model: Model,
facts: HashMap<OutletId, PulsedTensorFact>,
}
impl PulsifiedModel {
fn new(old: &Model, pulse: usize) -> TractResult<PulsifiedModel> {
let mut model = Model::default();
let mut mapping: HashMap<OutletId, OutletId> = HashMap::new();
let mut facts: HashMap<OutletId, PulsedTensorFact> = HashMap::new();
for old_id in old.eval_order()? {
if old.node(old_id).op_as::<Source>().is_some() {
let node = old.node(old_id);
let pulsed_fact =
PulsedTensorFact::from_tensor_fact_pulse(&node.outputs[0].fact, pulse)?;
let id = model.add_source_fact(node.name.clone(), pulsed_fact.to_pulse_fact())?;
facts.insert(OutletId::new(id, 0), pulsed_fact);
mapping.insert(OutletId::new(old_id, 0), OutletId::new(id, 0));
} else {
let pulsed_chain = {
let inputs = old
.node(old_id)
.inputs
.iter()
.map(|i| &facts[&mapping[i]])
.collect();
trace!(" inputs: {:?}", inputs);
old.node(old_id).op().pulsify(inputs)?
};
let mut previous = None;
let count = pulsed_chain.len();
for (ix, pulsed) in pulsed_chain.into_iter().enumerate() {
let PulsifiedOp { op, outputs } = pulsed;
let name = if ix == count - 1 {
old.node(old_id).name.clone()
} else {
format!("{}#{}", old.node(old_id).name, ix)
};
let new_id = model.add_node(name, op)?;
if let Some(prev) = previous {
model.add_edge(OutletId::new(prev, 0), InletId::new(new_id, 0))?;
} else {
for (ix, input) in old.node(old_id).inputs.iter().enumerate() {
model.add_edge(mapping[&input], InletId::new(new_id, ix))?;
}
};
previous = Some(new_id);
for (ix, output_fact) in outputs.into_iter().enumerate() {
model.set_fact(OutletId::new(new_id, ix), output_fact.to_pulse_fact())?;
facts.insert(OutletId::new(new_id, ix), output_fact);
mapping.insert(OutletId::new(old_id, ix), OutletId::new(new_id, ix));
}
}
}
}
Ok(PulsifiedModel { model, facts })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_source_must_stream() {
let mut model = Model::default();
let _a = model
.add_source_fact("a", TensorFact::dt_shape(DatumType::F32, vec![1, 2, 3]))
.unwrap();
assert!(PulsifiedModel::new(&model, 4).is_err());
let mut model = Model::default();
let _a = model
.add_source_fact(
"a",
TensorFact::dt_shape(DatumType::F32, vec![1.to_dim(), TDim::s(), 3.to_dim()]),
)
.unwrap();
let mut pulse = PulsifiedModel::new(&model, 4).unwrap();
pulse.model.analyse().unwrap();
assert_eq!(
pulse.model.fact(OutletId::new(0, 0)).unwrap(),
&TensorFact::dt_shape(DatumType::F32, vec!(1, 4, 3))
);
}
#[test]
fn test_immediate() {
let mut model = Model::default();
let _a = model
.add_source_fact(
"a",
TensorFact::dt_shape(DatumType::F32, vec![TDim::s(), 2.to_dim(), 3.to_dim()]),
)
.unwrap();
let pulse = PulsifiedModel::new(&model, 4).unwrap();
assert_eq!(
pulse.model.input_fact().unwrap(),
&TensorFact::dt_shape(DatumType::F32, vec!(4, 2, 3))
);
assert_eq!(
pulse.model.output_fact().unwrap(),
&TensorFact::dt_shape(DatumType::F32, vec!(4, 2, 3))
);
}
#[test]
fn test_simple_conv() {
use crate::ops::nn::*;
use ndarray::*;
let mut model = Model::default();
let ker = model
.add_const("kernel", arr3(&[[[0.5f32, 1.0, -0.1]]]).into())
.unwrap();
let _ = model
.add_source_fact("a", TensorFact::shape(shapefact!(1, 1, S)))
.unwrap();
let conv = model.chain("conv", Box::new(Conv::default())).unwrap();
model
.add_edge(OutletId::new(ker, 0), InletId::new(conv, 1))
.unwrap();
model.analyse().unwrap();
assert_eq!(model.nodes().len(), 3);
let input = [1.0f32, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0];
let t_input = Tensor::from(arr3(&[[input]]));
let model = model.into_optimized().unwrap();
assert_eq!(model.nodes().len(), 2);
let plan = crate::plan::SimplePlan::new(&model).unwrap();
let outputs = plan.run(tvec!(t_input.clone())).unwrap();
let pulse = 4;
let pulsed = PulsifiedModel::new(&model, pulse).unwrap();
assert_eq!(pulsed.model.nodes().len(), 3);
assert_eq!(pulsed.facts[&OutletId::new(2, 0)].delay, 2);
let pulsed_plan = crate::plan::SimplePlan::new(pulsed.model).unwrap();
let mut state = crate::plan::SimpleState::new(&pulsed_plan).unwrap();
let mut got: Vec<f32> = vec![];
for p in 0..(input.len() / pulse) {
let chunk = &input[(p * pulse)..((p + 1) * pulse)];
let mut outputs = state
.run(tvec!(ndarray::Array::from_shape_vec(
(1usize, 1, 4),
chunk.to_vec()
)
.unwrap()
.into()))
.unwrap();
got.extend(outputs.remove(0).to_array_view::<f32>().unwrap().iter());
}
assert_eq!(
&got[2..],
outputs[0]
.to_array_view::<f32>()
.unwrap()
.as_slice()
.unwrap()
);
}
}