tract_hir/ops/
source.rs

1use crate::infer::*;
2use crate::internal::*;
3
4use tract_core::ops::source::{SourceState, TypedSource};
5
6#[derive(Debug, Clone, new, Hash)]
7pub struct Source;
8
9
10
11impl Op for Source {
12    fn name(&self) -> StaticName {
13        "Source".into()
14    }
15
16    not_a_typed_op!();
17}
18
19impl EvalOp for Source {
20    fn is_stateless(&self) -> bool {
21        false
22    }
23    fn state(
24        &self,
25        _session: &mut SessionState,
26        node_id: usize,
27    ) -> TractResult<Option<Box<dyn OpState>>> {
28        Ok(Some(Box::new(SourceState(node_id))))
29    }
30}
31
32impl InferenceRulesOp for Source {
33    /// Registers the inference rules of the operator.
34    fn rules<'r, 'p: 'r, 's: 'r>(
35        &'s self,
36        _s: &mut Solver<'r>,
37        inputs: &'p [TensorProxy],
38        outputs: &'p [TensorProxy],
39    ) -> InferenceResult {
40        check_input_arity(inputs, 0)?;
41        check_output_arity(outputs, 1)?;
42        Ok(())
43    }
44
45    as_op!();
46
47    fn to_typed(
48        &self,
49        _source: &InferenceModel,
50        node: &InferenceNode,
51        target: &mut TypedModel,
52        _mapping: &HashMap<OutletId, OutletId>,
53    ) -> TractResult<TVec<OutletId>> {
54        if let Ok(fact) = TypedFact::try_from(&node.outputs[0].fact) {
55            target.wire_node(&*node.name, TypedSource::new(fact), &[])
56        } else {
57            bail!("Source node without a determined fact. Help: provide explicit input facts to your model.")
58        }
59    }
60}