Skip to main content

tract_core/
plan.rs

1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4
5use multithread::Executor;
6
7use crate::internal::*;
8use crate::model::{Fact, Graph, OutletId};
9use crate::ops::FrozenOpState;
10use crate::ops::konst::Const;
11use crate::runtime::RunOptions;
12
13use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
14
15pub struct TurnState {
16    pub resolved_symbols: SymbolValues,
17    pub scenario: Option<usize>,
18    pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
19    pub scratch_extensions: anymap3::Map,
20    pub values: Vec<Option<TVec<TValue>>>,
21}
22
23impl Default for TurnState {
24    fn default() -> Self {
25        TurnState {
26            resolved_symbols: SymbolValues::default(),
27            scenario: None,
28            cached_mmm_scratch_space: None.into(),
29            scratch_extensions: anymap3::Map::new(),
30            values: vec![],
31        }
32    }
33}
34
35impl Clone for TurnState {
36    fn clone(&self) -> Self {
37        TurnState {
38            resolved_symbols: self.resolved_symbols.clone(),
39            scenario: self.scenario,
40            cached_mmm_scratch_space: None.into(),
41            scratch_extensions: anymap3::Map::new(),
42            values: vec![],
43        }
44    }
45}
46
47pub trait SessionStateHandler: Send + Sync + Debug {
48    fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
49    fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
50}
51
52impl Debug for TurnState {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "SessionState({:?})", self.resolved_symbols)
55    }
56}
57
58#[derive(Debug, Clone)]
59pub struct SimplePlan<F, O>
60where
61    F: Fact + Clone + 'static,
62    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
63{
64    pub(crate) model: Arc<Graph<F, O>>,
65    outputs: Vec<OutletId>,
66    order: Vec<usize>,
67    flush_lists: Vec<TVec<usize>>,
68    has_unresolved_symbols: bool,
69    symbols: Vec<Symbol>,
70    executor: Option<Executor>,
71    session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
72}
73
74impl<F, O> SimplePlan<F, O>
75where
76    F: Fact + Clone + 'static,
77    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
78{
79    /// This contructor returns a plan that will compute all the model default outputs in one pass.
80    pub fn new(model: impl Into<Arc<Graph<F, O>>>) -> TractResult<Arc<SimplePlan<F, O>>> {
81        let model = model.into();
82        Self::build(model, &RunOptions::default()).map(Arc::new)
83    }
84
85    /// This contructor returns a plan that will compute all the model default outputs in one pass.
86    pub fn new_with_options(
87        model: impl Into<Arc<Graph<F, O>>>,
88        options: &RunOptions,
89    ) -> TractResult<Arc<SimplePlan<F, O>>> {
90        let model = model.into();
91        Self::build(model, options).map(Arc::new)
92    }
93
94    /// This contructor returns a plan that will compute the specified output.
95    #[deprecated]
96    pub fn new_for_output(
97        model: Graph<F, O>,
98        output: OutletId,
99    ) -> TractResult<Arc<SimplePlan<F, O>>> {
100        #[allow(deprecated)]
101        Self::build_with_outputs_and_deps(model, &[output], &[], &RunOptions::default())
102            .map(Arc::new)
103    }
104
105    /// This contructor returns a plan that will compute all specified outputs in one pass.
106    #[deprecated]
107    pub fn new_for_outputs(
108        model: impl Into<Arc<Graph<F, O>>>,
109        outputs: &[OutletId],
110    ) -> TractResult<Arc<SimplePlan<F, O>>> {
111        #[allow(deprecated)]
112        Self::build_with_outputs_and_deps(model, outputs, &[], &RunOptions::default()).map(Arc::new)
113    }
114
115    pub fn with_session_handler<H: SessionStateHandler + 'static>(
116        mut self,
117        session_handler: H,
118    ) -> Self {
119        self.session_handler = Some(Arc::new(session_handler));
120        self
121    }
122
123    #[deprecated]
124    pub fn new_for_outputs_and_deps(
125        model: impl Into<Arc<Graph<F, O>>>,
126        outputs: &[OutletId],
127        deps: &[(usize, usize)],
128    ) -> TractResult<Arc<SimplePlan<F, O>>> {
129        #[allow(deprecated)]
130        Self::build_with_outputs_and_deps(model, outputs, deps, &RunOptions::default())
131            .map(Arc::new)
132    }
133
134    pub fn build(
135        model: impl Into<Arc<Graph<F, O>>>,
136        options: &RunOptions,
137    ) -> TractResult<SimplePlan<F, O>> {
138        let model = model.into();
139        let outputs = model.outputs.clone();
140        #[allow(deprecated)]
141        Self::build_with_outputs_and_deps(model, &outputs, &[], options)
142    }
143
144    #[deprecated]
145    pub fn build_with_outputs_and_deps(
146        model: impl Into<Arc<Graph<F, O>>>,
147        outputs: &[OutletId],
148        deps: &[(usize, usize)],
149        options: &RunOptions,
150    ) -> TractResult<SimplePlan<F, O>> {
151        let model = model.into();
152        let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
153        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
154        let mut order = if options.skip_order_opt_ram {
155            eval_order_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
156        } else {
157            eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
158        };
159        order.retain(|node| !model.node(*node).op_is::<Const>());
160        let flush_lists = build_flush_list(&*model, &order, outputs, |n| !n.op_is::<Const>());
161
162        #[allow(clippy::mutable_key_type)]
163        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
164        for node in &model.nodes {
165            for output in &node.outputs {
166                if let Ok(fact) = output.fact.to_typed_fact() {
167                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
168                }
169            }
170        }
171        Ok(SimplePlan {
172            model,
173            order,
174            flush_lists,
175            outputs: outputs.to_vec(),
176            has_unresolved_symbols: !symbols.is_empty(),
177            symbols: symbols.into_iter().collect(),
178            executor: options.executor.clone(),
179            session_handler: None,
180        })
181    }
182
183    pub fn order_without_consts(&self) -> &[usize] {
184        &self.order
185    }
186
187    pub fn run(self: &Arc<Self>, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
188        let mut state = self.spawn()?;
189        state.run(inputs)
190    }
191
192    pub fn model(&self) -> &Graph<F, O> {
193        self.model.borrow()
194    }
195
196    pub fn spawn(self: &Arc<Self>) -> TractResult<SimpleState<F, O>> {
197        SimpleState::new(self)
198    }
199}
200
201#[derive(Clone, Debug)]
202pub struct SimpleState<F, O>
203where
204    F: Fact + Clone + 'static,
205    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
206{
207    pub(crate) plan: Arc<SimplePlan<F, O>>,
208    pub op_states: Vec<Option<Box<dyn OpState>>>,
209    pub turn_state: TurnState,
210}
211
212impl<F, O> SimpleState<F, O>
213where
214    F: Fact + Clone + 'static,
215    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
216{
217    pub fn new(plan: &Arc<SimplePlan<F, O>>) -> TractResult<SimpleState<F, O>> {
218        let plan = Arc::clone(plan);
219        let turn = TurnState::default();
220        let model = plan.model();
221        let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
222        let mut state = SimpleState { plan, op_states: states, turn_state: turn };
223        state.reset_op_states()?;
224        Ok(state)
225    }
226
227    pub fn new_from_inputs(
228        plan: &Arc<SimplePlan<F, O>>,
229        inputs: TVec<TValue>,
230    ) -> TractResult<SimpleState<F, O>> {
231        let mut state = SimpleState::new(plan)?;
232        state.set_inputs(inputs)?;
233        state.resolve_symbols_with_states()?;
234
235        Ok(state)
236    }
237
238    fn ready_turn(&mut self) {
239        if self.turn_state.values.len() == 0 {
240            self.turn_state.values = vec![None; self.plan.model.nodes().len()];
241            for node in &self.plan.model.nodes {
242                if let Some(k) = node.op_as::<Const>() {
243                    self.turn_state.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
244                }
245            }
246        }
247    }
248    /// Reset wires state.
249    pub fn reset_turn(&mut self) -> TractResult<()> {
250        self.reset_turn_keep_symbols();
251        self.turn_state.resolved_symbols = SymbolValues::default();
252        Ok(())
253    }
254
255    /// Like [`reset_turn`] but keeps the resolved symbols (and scenario). Used by
256    /// `Scan`/`Loop` bodies, whose shapes are constant across iterations: it lets
257    /// the body resolve its symbols once and skip the per-iteration re-resolution
258    /// the full `reset_turn` + `run` cycle would otherwise force.
259    pub(crate) fn reset_turn_keep_symbols(&mut self) {
260        for node in &self.plan.order {
261            self.turn_state.values[*node] = None;
262        }
263    }
264
265    /// Clear resolved symbols (and scenario) without touching node values. Used at
266    /// the start of a fresh `Scan` evaluation, since the body state persists across
267    /// outer calls and a previous call may have left stale symbol resolutions.
268    pub(crate) fn clear_resolved_symbols(&mut self) {
269        self.turn_state.resolved_symbols = SymbolValues::default();
270        self.turn_state.scenario = None;
271    }
272
273    /// Reset op inner state.
274    fn reset_op_states(&mut self) -> TractResult<()> {
275        let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
276        for (ix, n) in plan.model.nodes.iter().enumerate() {
277            states[ix] = if n.op().is_stateless() { None } else { n.op().state(turn_state, ix)? };
278        }
279        Ok(())
280    }
281
282    pub(crate) fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
283        for state in self
284            .op_states
285            .iter_mut()
286            .filter_map(Option::as_mut)
287            .filter(|s| s.init_tensor_fact().is_some())
288        {
289            state.resolve_symbols(&mut self.turn_state)?;
290        }
291        Ok(())
292    }
293
294    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
295        self.run_plan_with_eval(inputs, self::eval)
296    }
297
298    pub fn exec(&mut self) -> TractResult<()> {
299        self.exec_plan_with_eval(self::eval)
300    }
301
302    pub fn run_plan_with_eval<Eval, E>(
303        &mut self,
304        inputs: TVec<TValue>,
305        eval: Eval,
306    ) -> TractResult<TVec<TValue>>
307    where
308        Eval: for<'a, 'b, 'c> FnMut(
309            &'a mut TurnState,
310            Option<&'b mut (dyn OpState + 'static)>,
311            &'c Node<F, O>,
312            TVec<TValue>,
313        ) -> Result<TVec<TValue>, E>,
314        E: Into<anyhow::Error> + Send + Sync + 'static,
315    {
316        self.set_inputs(inputs)?;
317        self.resolve_symbols_with_states()?;
318        self.exec_plan_with_eval(eval)?;
319        let outputs = self.outputs()?;
320        self.reset_turn()?;
321        Ok(outputs)
322    }
323
324    pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
325    where
326        Eval: for<'a, 'b, 'c> FnMut(
327            &'a mut TurnState,
328            Option<&'b mut (dyn OpState + 'static)>,
329            &'c Node<F, O>,
330            TVec<TValue>,
331        ) -> Result<TVec<TValue>, E>,
332        E: Into<anyhow::Error> + Send + Sync + 'static,
333    {
334        if let Some(executor) = self.plan().executor.as_ref() {
335            tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
336                self.do_exec_plan_with_eval(eval)
337            })
338        } else {
339            self.do_exec_plan_with_eval(eval)
340        }
341    }
342
343    fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
344    where
345        Eval: for<'a, 'b, 'c> FnMut(
346            &'a mut TurnState,
347            Option<&'b mut (dyn OpState + 'static)>,
348            &'c Node<F, O>,
349            TVec<TValue>,
350        ) -> Result<TVec<TValue>, E>,
351        E: Into<anyhow::Error> + Send + Sync + 'static,
352    {
353        {
354            self.ready_turn();
355            self.plan
356                .session_handler
357                .as_ref()
358                .map(|it| it.before_plan_eval(&mut self.turn_state))
359                .transpose()?;
360
361            let mut syms_done = !self.plan.has_unresolved_symbols
362                || self
363                    .plan
364                    .symbols
365                    .iter()
366                    .all(|s| self.turn_state.resolved_symbols.get(s).is_some());
367
368            for (step, n) in self.plan.order.iter().enumerate() {
369                let node = self.plan.model.node(*n);
370                trace!("Running step {step}, node {node}");
371                let mut inputs: TVec<TValue> = tvec![];
372                for i in &node.inputs {
373                    trace!("  use input {i:?}");
374                    let prec_node = self.plan.model.node(i.node);
375                    let prec = self.turn_state.values[i.node].as_ref().ok_or_else(|| {
376                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
377                    })?;
378                    inputs.push(prec[i.slot].clone())
379                }
380
381                for flush in &self.plan.flush_lists[step] {
382                    trace!("  Ran {} can now flush {}", node, self.plan.model.node(*flush));
383                    self.turn_state.values[*flush] = None;
384                }
385
386                if cfg!(debug_assertions) {
387                    let facts = self.plan.model.node_input_facts(node.id)?;
388                    if facts.len() != inputs.len() {
389                        bail!(
390                            "Evaluating {}: expected {} inputs, got {}",
391                            node,
392                            facts.len(),
393                            inputs.len()
394                        );
395                    }
396                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
397                        if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
398                            bail!(
399                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
400                                node,
401                                ix,
402                                f,
403                                v
404                            );
405                        }
406                    }
407                }
408
409                let vs = eval(
410                    &mut self.turn_state,
411                    self.op_states[node.id].as_deref_mut(),
412                    node,
413                    inputs,
414                )
415                .map_err(|e| e.into())?;
416
417                if !syms_done && self.plan.has_unresolved_symbols {
418                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
419                        if let Ok(f) = o.fact.to_typed_fact() {
420                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
421                                Self::resolve(
422                                    &mut self.turn_state,
423                                    dim_abstract,
424                                    *dim_concrete as i64,
425                                )?;
426                            }
427                        }
428                    }
429                    if self
430                        .plan
431                        .symbols
432                        .iter()
433                        .all(|s| self.turn_state.resolved_symbols.get(s).is_some())
434                    {
435                        syms_done = true;
436                    }
437                }
438                if cfg!(debug_assertions) {
439                    let facts = self.plan.model.node_output_facts(node.id)?;
440                    if facts.len() != vs.len() {
441                        bail!(
442                            "Evaluating {}: expected {} outputs, got {}",
443                            node,
444                            facts.len(),
445                            vs.len()
446                        );
447                    }
448                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
449                        if node.outputs[ix].successors.len() == 0 {
450                            continue;
451                        }
452                        if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
453                            bail!(
454                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
455                                node,
456                                ix,
457                                f,
458                                v
459                            );
460                        }
461                    }
462                }
463
464                self.turn_state.values[node.id] = Some(vs);
465            }
466            self.plan
467                .session_handler
468                .as_ref()
469                .map(|it| it.after_plan_eval(&mut self.turn_state))
470                .transpose()?;
471        }
472        Ok(())
473    }
474
475    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
476        ensure!(
477            inputs.len() == self.model().inputs.len(),
478            "Wrong number of inputs for model. Expected {} got {}",
479            self.model().inputs.len(),
480            inputs.len()
481        );
482
483        for (ix, t) in inputs.into_iter().enumerate() {
484            self.set_input(ix, t)?
485        }
486        Ok(())
487    }
488
489    /// Like [`set_inputs`] but drains the caller's buffer (leaving it empty with
490    /// its capacity intact) instead of consuming it, so a repeated caller (a
491    /// `Scan` body loop) can reuse one allocation across iterations.
492    pub(crate) fn set_inputs_drain(&mut self, inputs: &mut TVec<TValue>) -> TractResult<()> {
493        ensure!(
494            inputs.len() == self.model().inputs.len(),
495            "Wrong number of inputs for model. Expected {} got {}",
496            self.model().inputs.len(),
497            inputs.len()
498        );
499        for (ix, t) in inputs.drain(..).enumerate() {
500            self.set_input(ix, t)?
501        }
502        Ok(())
503    }
504
505    fn resolve(state: &mut TurnState, expression: &TDim, provided: i64) -> TractResult<()> {
506        if let TDim::Sym(sym) = expression
507            && state.resolved_symbols.get(sym).is_none()
508        {
509            state.resolved_symbols.set(sym, provided);
510            if state.scenario.is_none() {
511                let scope = sym.scope().with_context(|| {
512                    format!(
513                        "Symbol {sym:?} points to an invalid (dead ?) SymbolScope. \
514                         Make sure to create symbols using the model-managed SymbolScope."
515                    )
516                })?;
517                state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
518            }
519            return Ok(());
520        }
521        let expected = expression.eval(&state.resolved_symbols);
522        if let Some(x) = expected.as_i64()
523            && x != provided
524        {
525            bail!("Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})")
526        }
527        if expected.symbols().len() == 1 {
528            let sym = expected.symbols().into_iter().next().unwrap();
529            if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
530                debug!("Determined symbol {sym}={v}");
531                state.resolved_symbols.set(&sym, v.to_i64().unwrap());
532            }
533            if state.scenario.is_none() {
534                let scope = sym
535                    .scope()
536                    .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
537                state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
538            }
539        }
540        Ok(())
541    }
542
543    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
544        let outlet: OutletId = *self
545            .model()
546            .input_outlets()?
547            .get(input)
548            .with_context(|| format!("Invalid input id for model ({input})."))?;
549        if let Ok(fact) = self.plan.model.outlet_fact(outlet)?.to_typed_fact() {
550            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
551                Self::resolve(&mut self.turn_state, expected, *provided as i64)?;
552            }
553        }
554        let fact = self.plan.model.outlet_fact(outlet)?;
555        ensure!(
556            fact.matches(&t, Some(&self.turn_state.resolved_symbols))
557                .with_context(|| format!("Setting input {input}"))?,
558            "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
559        );
560        self.ready_turn();
561        self.turn_state.values[outlet.node] = Some(tvec!(t));
562        Ok(())
563    }
564
565    pub fn output(&self, id: usize) -> TractResult<&TValue> {
566        let outlet = self.model().output_outlets()?.get(id).with_context(|| {
567            format!(
568                "Required output {}, only have {}",
569                id,
570                self.model().output_outlets().unwrap().len()
571            )
572        })?;
573        let value: &TValue = self
574            .turn_state
575            .values
576            .get(outlet.node)
577            .context("node id for output beyond node values array")?
578            .as_ref()
579            .context("node is not an output")?
580            .get(outlet.slot)
581            .context("slot id too high")?;
582        Ok(value)
583    }
584
585    pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
586        let &mut SimpleState { ref plan, ref mut turn_state, .. } = self;
587        let mut v = tvec![];
588        for o in plan.outputs.iter() {
589            let vs = turn_state.values[o.node].as_mut().ok_or_else(|| {
590                format_err!("Outputs of {:?} are not computed", &plan.model.nodes()[o.node])
591            })?;
592            v.push(vs[o.slot].clone())
593        }
594        Ok(v)
595    }
596
597    pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
598        self.turn_state.values[id] = Some(values);
599        Ok(())
600    }
601
602    pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
603        self.set_values(id, tvec!(value))
604    }
605
606    pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
607        let SimpleState { plan, turn_state, .. } = self;
608        let nodes = plan.model.nodes();
609        let node = &nodes[node];
610        let mut inputs: TVec<TValue> = tvec![];
611        for i in &node.inputs {
612            let prec_node = &nodes[i.node];
613            let prec = turn_state.values[i.node].as_ref().ok_or_else(|| {
614                format_err!("Computing {}, precursor {} not done.", node, prec_node)
615            })?;
616            inputs.push(prec[i.slot].clone())
617        }
618        Ok(inputs)
619    }
620
621    pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
622        let inputs = self.prepare_inputs(node)?;
623        self.compute_one_with_inputs(node, inputs)
624    }
625
626    pub fn compute_one_with_inputs(
627        &mut self,
628        node: usize,
629        inputs: TVec<TValue>,
630    ) -> TractResult<()> {
631        let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
632        let nodes = plan.model.nodes();
633        let node = &nodes[node];
634        let vs = eval(turn_state, states[node.id].as_deref_mut(), node, inputs)?;
635        turn_state.values[node.id] = Some(vs);
636        Ok(())
637    }
638
639    pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
640        let values = {
641            #[allow(clippy::needless_collect)] // clippy bug ?
642            let precs: Vec<usize> =
643                self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
644            for i in precs.into_iter() {
645                if self.turn_state.values[i].is_none() {
646                    let _ = self.compute_recursively(i)?;
647                }
648            }
649            let mut inputs: TVec<TValue> = tvec![];
650            {
651                let node = &self.model().nodes()[node];
652                for i in &node.inputs {
653                    inputs.push(self.turn_state.values[i.node].as_ref().unwrap()[i.slot].clone())
654                }
655            }
656            let &mut Self {
657                op_states: ref mut states,
658                turn_state: ref mut session_state,
659                ref plan,
660                ..
661            } = self;
662            eval(session_state, states[node].as_deref_mut(), &plan.model().nodes[node], inputs)?
663        };
664        self.turn_state.values[node] = Some(values);
665        Ok(self.turn_state.values[node].as_ref().unwrap())
666    }
667
668    pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
669        let id = self.model().node_by_name(name)?.id;
670        Self::take(self, id)
671    }
672
673    pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
674        Ok(self.turn_state.values[id]
675            .take()
676            .ok_or_else(|| format_err!("Node is not computed"))?
677            .into_iter()
678            .map(|v| v.into_tensor())
679            .collect())
680    }
681
682    pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
683        &self.plan
684    }
685
686    pub fn model(&self) -> &Graph<F, O> {
687        &self.plan.model
688    }
689
690    pub fn freeze(&self) -> FrozenSimpleState<F, O> {
691        FrozenSimpleState {
692            plan: self.plan.clone(),
693            resolved_symbols: self.turn_state.resolved_symbols.clone(),
694            scenario: self.turn_state.scenario,
695            states: self.op_states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
696            values: self
697                .turn_state
698                .values
699                .iter()
700                .enumerate()
701                .map(|(ix, t)| {
702                    if self.model().nodes[ix].op_is::<Const>() {
703                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
704                    } else {
705                        None
706                    }
707                })
708                .collect(),
709        }
710    }
711
712    pub fn freeze_into(self) -> FrozenSimpleState<F, O> {
713        let plan = self.plan;
714        let model = &plan.model;
715        FrozenSimpleState {
716            resolved_symbols: self.turn_state.resolved_symbols,
717            scenario: self.turn_state.scenario,
718            states: self.op_states.into_iter().map(|s| s.map(|s| s.freeze_into())).collect(),
719            values: self
720                .turn_state
721                .values
722                .into_iter()
723                .enumerate()
724                .map(|(ix, t)| {
725                    if model.nodes[ix].op_is::<Const>() {
726                        t.map(|t| t.into_iter().map(|t| t.into_tensor()).collect())
727                    } else {
728                        None
729                    }
730                })
731                .collect(),
732            plan,
733        }
734    }
735}
736
737pub fn eval<F, O>(
738    session_state: &mut TurnState,
739    mut state: Option<&mut (dyn OpState + 'static)>,
740    node: &Node<F, O>,
741    input: TVec<TValue>,
742) -> TractResult<TVec<TValue>>
743where
744    F: Fact + Clone + 'static,
745    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
746{
747    // eprint!("{node} {input:?}");
748    #[allow(clippy::let_and_return)]
749    let r = match state {
750        Some(ref mut state) => state.eval(session_state, node.op(), input),
751        None => node.op().eval_with_session(node.id, session_state, input),
752    }
753    .with_context(|| format!("Evaluating {node}"));
754    // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
755    r
756}
757
758#[derive(Clone, Debug)]
759pub struct FrozenSimpleState<F, O>
760where
761    F: Fact + Clone + 'static,
762    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
763{
764    plan: Arc<SimplePlan<F, O>>,
765    pub resolved_symbols: SymbolValues,
766    pub scenario: Option<usize>,
767    pub states: Vec<Option<Box<dyn FrozenOpState>>>,
768    pub values: Vec<Option<TVec<Tensor>>>,
769}
770
771impl<F, O> FrozenSimpleState<F, O>
772where
773    F: Fact + Clone + 'static,
774    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
775{
776    pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
777        &self.plan
778    }
779
780    pub fn unfreeze(&self) -> SimpleState<F, O> {
781        SimpleState {
782            plan: self.plan.clone(),
783            turn_state: TurnState {
784                resolved_symbols: self.resolved_symbols.clone(),
785                scenario: self.scenario,
786                cached_mmm_scratch_space: None.into(),
787                scratch_extensions: anymap3::Map::new(),
788                values: self
789                    .values
790                    .iter()
791                    .map(|t| {
792                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect())
793                    })
794                    .collect(),
795            },
796            op_states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
797        }
798    }
799}
800
#[cfg(test)]
mod test {
    use super::*;

    /// Compiles only when `T: Send`.
    fn assert_send<T: Send>() {}
    /// Compiles only when `T: Sync`.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>();
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>();
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan>();
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan>();
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState>();
    }
}