tract_core/ops/source.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, new)]
4pub struct SourceState(pub usize);
5trivial_op_state_freeeze!(SourceState);
6
7impl OpState for SourceState {
8    fn eval(
9        &mut self,
10        session: &mut SessionState,
11        _op: &dyn Op,
12        _inputs: TVec<TValue>,
13    ) -> TractResult<TVec<TValue>> {
14        Ok(tvec!(session
15            .inputs
16            .get(&self.0)
17            .with_context(|| format!("Input for node {} is missing", self.0))?
18            .clone()))
19    }
20}
21
22#[derive(Debug, Clone, new, Hash)]
23pub struct TypedSource {
24    pub fact: TypedFact,
25}
26
27
28
29impl Op for TypedSource {
30    fn name(&self) -> Cow<str> {
31        "Source".into()
32    }
33    op_as_typed_op!();
34}
35
36impl EvalOp for TypedSource {
37    fn is_stateless(&self) -> bool {
38        false
39    }
40
41    fn state(
42        &self,
43        _session: &mut SessionState,
44        node_id: usize,
45    ) -> TractResult<Option<Box<dyn OpState>>> {
46        Ok(Some(Box::new(SourceState(node_id))))
47    }
48}
49
50impl TypedOp for TypedSource {
51    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
52        Ok(tvec!(self.fact.clone()))
53    }
54
55    fn change_axes(
56        &self,
57        model: &TypedModel,
58        node: &TypedNode,
59        _io: InOut,
60        change: &AxisOp,
61    ) -> TractResult<Option<AxisChangeConsequence>> {
62        let mut fact = self.fact.clone();
63        change.change_shape(&mut fact.shape, false)?;
64        Ok(Some(AxisChangeConsequence::new(
65            model,
66            node,
67            Some(Box::new(TypedSource::new(fact))),
68            change,
69        )))
70    }
71
72    fn concretize_dims(
73        &self,
74        _source: &TypedModel,
75        node: &TypedNode,
76        target: &mut TypedModel,
77        _mapping: &HashMap<OutletId, OutletId>,
78        values: &SymbolValues,
79    ) -> TractResult<TVec<OutletId>> {
80        let shape: TVec<_> = self.fact.shape.iter().map(|d| d.eval(values)).collect();
81        target.wire_node(&node.name, Self { fact: self.fact.datum_type.fact(&*shape) }, &[])
82    }
83
84    as_op!();
85}