// tract_core/ops/source.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, new)]
4pub struct SourceState(pub usize);
5trivial_op_state_freeze!(SourceState);
6
7impl OpState for SourceState {
8    fn eval(
9        &mut self,
10        session: &mut TurnState,
11        _op: &dyn Op,
12        _inputs: TVec<TValue>,
13    ) -> TractResult<TVec<TValue>> {
14        Ok(tvec!(
15            session
16                .values
17                .get(self.0)
18                .and_then(|v| v.as_ref())
19                .and_then(|vs| vs.first().cloned())
20                .with_context(|| format!("Input for node {} is missing", self.0))?
21        ))
22    }
23}
24
25#[derive(Debug, Clone, new, Hash)]
26pub struct TypedSource {
27    pub fact: TypedFact,
28}
29
30impl Op for TypedSource {
31    fn name(&self) -> StaticName {
32        "Source".into()
33    }
34    op_as_typed_op!();
35}
36
37impl EvalOp for TypedSource {
38    fn is_stateless(&self) -> bool {
39        false
40    }
41
42    fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
43        Ok(Some(Box::new(SourceState(node_id))))
44    }
45}
46
47impl TypedOp for TypedSource {
48    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
49        Ok(tvec!(self.fact.clone()))
50    }
51
52    fn change_axes(
53        &self,
54        model: &TypedModel,
55        node: &TypedNode,
56        _io: InOut,
57        change: &AxisOp,
58    ) -> TractResult<Option<AxisChangeConsequence>> {
59        let mut fact = self.fact.clone();
60        change.change_shape(&mut fact.shape, false)?;
61        Ok(Some(AxisChangeConsequence::new(
62            model,
63            node,
64            Some(Box::new(TypedSource::new(fact))),
65            change,
66        )))
67    }
68
69    fn concretize_dims(
70        &self,
71        _source: &TypedModel,
72        node: &TypedNode,
73        target: &mut TypedModel,
74        _mapping: &HashMap<OutletId, OutletId>,
75        values: &SymbolValues,
76    ) -> TractResult<TVec<OutletId>> {
77        let shape: TVec<_> = self.fact.shape.iter().map(|d| d.eval(values)).collect();
78        target.wire_node(&node.name, Self { fact: self.fact.datum_type.fact(&*shape) }, &[])
79    }
80
81    as_op!();
82}