tract_core/ops/memory/
force_eval.rs1use crate::internal::*;
2
/// Operator that consumes the inputs whose indexes are listed in `slots` and
/// forwards every other input unchanged, keeping their relative order.
///
/// NOTE(review): judging by the name, consuming those inputs is meant to force
/// their producers to be evaluated — confirm against the callers that insert it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForceEval {
    /// Input indexes that are consumed by the op and removed from its outputs.
    pub slots: Vec<usize>,
}
7
8impl ForceEval {
9 pub fn new(slots: Vec<usize>) -> ForceEval {
10 ForceEval { slots }
11 }
12}
13
14impl Op for ForceEval {
15 fn name(&self) -> StaticName {
16 "ForceEval".into()
17 }
18
19 fn info(&self) -> TractResult<Vec<String>> {
20 Ok(vec![format!("slots: {:?}", self.slots)])
21 }
22
23 impl_op_same_as!();
24 op_as_typed_op!();
25}
26
27impl EvalOp for ForceEval {
28 fn is_stateless(&self) -> bool {
29 true
30 }
31
32 fn state(
33 &self,
34 _session: &mut SessionState,
35 _node_id: usize,
36 ) -> TractResult<Option<Box<dyn OpState>>> {
37 Ok(None)
38 }
39
40 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
41 let max_slot_idx = self.slots.iter().copied().max().unwrap_or(0);
42 ensure!(inputs.len() > max_slot_idx, format!("Expected at least {} inputs given the slot indexes that needs to be forced eval: {:?}", max_slot_idx + 1, self.slots));
43 let outputs = inputs
44 .into_iter()
45 .enumerate()
46 .filter_map(|(idx, val)| if !self.slots.contains(&idx) { Some(val) } else { None })
47 .collect::<TVec<_>>();
48 Ok(outputs)
49 }
50}
51
52impl TypedOp for ForceEval {
53 as_op!();
54
55 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
56 let output_facts = inputs
57 .iter()
58 .enumerate()
59 .filter_map(
60 |(idx, fact)| {
61 if !self.slots.contains(&idx) {
62 Some((*fact).clone())
63 } else {
64 None
65 }
66 },
67 )
68 .collect::<TVec<_>>();
69 Ok(output_facts)
70 }
71}