1use crate::internal::*;
2
3#[derive(Debug, Clone, new)]
4pub struct SourceState(pub usize);
5trivial_op_state_freeeze!(SourceState);
6
7impl OpState for SourceState {
8 fn eval(
9 &mut self,
10 session: &mut SessionState,
11 _op: &dyn Op,
12 _inputs: TVec<TValue>,
13 ) -> TractResult<TVec<TValue>> {
14 Ok(tvec!(session
15 .inputs
16 .get(&self.0)
17 .with_context(|| format!("Input for node {} is missing", self.0))?
18 .clone()))
19 }
20}
21
22#[derive(Debug, Clone, new, Hash)]
23pub struct TypedSource {
24 pub fact: TypedFact,
25}
26
27
28
29impl Op for TypedSource {
30 fn name(&self) -> Cow<str> {
31 "Source".into()
32 }
33 op_as_typed_op!();
34}
35
36impl EvalOp for TypedSource {
37 fn is_stateless(&self) -> bool {
38 false
39 }
40
41 fn state(
42 &self,
43 _session: &mut SessionState,
44 node_id: usize,
45 ) -> TractResult<Option<Box<dyn OpState>>> {
46 Ok(Some(Box::new(SourceState(node_id))))
47 }
48}
49
50impl TypedOp for TypedSource {
51 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
52 Ok(tvec!(self.fact.clone()))
53 }
54
55 fn change_axes(
56 &self,
57 model: &TypedModel,
58 node: &TypedNode,
59 _io: InOut,
60 change: &AxisOp,
61 ) -> TractResult<Option<AxisChangeConsequence>> {
62 let mut fact = self.fact.clone();
63 change.change_shape(&mut fact.shape, false)?;
64 Ok(Some(AxisChangeConsequence::new(
65 model,
66 node,
67 Some(Box::new(TypedSource::new(fact))),
68 change,
69 )))
70 }
71
72 fn concretize_dims(
73 &self,
74 _source: &TypedModel,
75 node: &TypedNode,
76 target: &mut TypedModel,
77 _mapping: &HashMap<OutletId, OutletId>,
78 values: &SymbolValues,
79 ) -> TractResult<TVec<OutletId>> {
80 let shape: TVec<_> = self.fact.shape.iter().map(|d| d.eval(values)).collect();
81 target.wire_node(&node.name, Self { fact: self.fact.datum_type.fact(&*shape) }, &[])
82 }
83
84 as_op!();
85}