1use crate::internal::*;
2
3#[derive(Debug, Clone, new)]
4pub struct SourceState(pub usize);
5trivial_op_state_freeze!(SourceState);
6
7impl OpState for SourceState {
8 fn eval(
9 &mut self,
10 session: &mut TurnState,
11 _op: &dyn Op,
12 _inputs: TVec<TValue>,
13 ) -> TractResult<TVec<TValue>> {
14 Ok(tvec!(
15 session
16 .values
17 .get(self.0)
18 .and_then(|v| v.as_ref())
19 .and_then(|vs| vs.first().cloned())
20 .with_context(|| format!("Input for node {} is missing", self.0))?
21 ))
22 }
23}
24
25#[derive(Debug, Clone, new, Hash)]
26pub struct TypedSource {
27 pub fact: TypedFact,
28}
29
30impl Op for TypedSource {
31 fn name(&self) -> StaticName {
32 "Source".into()
33 }
34 op_as_typed_op!();
35}
36
37impl EvalOp for TypedSource {
38 fn is_stateless(&self) -> bool {
39 false
40 }
41
42 fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
43 Ok(Some(Box::new(SourceState(node_id))))
44 }
45}
46
47impl TypedOp for TypedSource {
48 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
49 Ok(tvec!(self.fact.clone()))
50 }
51
52 fn change_axes(
53 &self,
54 model: &TypedModel,
55 node: &TypedNode,
56 _io: InOut,
57 change: &AxisOp,
58 ) -> TractResult<Option<AxisChangeConsequence>> {
59 let mut fact = self.fact.clone();
60 change.change_shape(&mut fact.shape, false)?;
61 Ok(Some(AxisChangeConsequence::new(
62 model,
63 node,
64 Some(Box::new(TypedSource::new(fact))),
65 change,
66 )))
67 }
68
69 fn concretize_dims(
70 &self,
71 _source: &TypedModel,
72 node: &TypedNode,
73 target: &mut TypedModel,
74 _mapping: &HashMap<OutletId, OutletId>,
75 values: &SymbolValues,
76 ) -> TractResult<TVec<OutletId>> {
77 let shape: TVec<_> = self.fact.shape.iter().map(|d| d.eval(values)).collect();
78 target.wire_node(&node.name, Self { fact: self.fact.datum_type.fact(&*shape) }, &[])
79 }
80
81 as_op!();
82}