Skip to main content

tract_core/
plan.rs

1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4
5use multithread::Executor;
6use tract_data::itertools::Itertools;
7
8use crate::internal::*;
9use crate::model::{Fact, Graph, OutletId};
10use crate::ops::FrozenOpState;
11use crate::ops::konst::Const;
12use crate::runtime::RunOptions;
13
14use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
15
/// Mutable per-turn context handed to every op evaluation of a [`SimpleState`].
///
/// It holds the node values of the current turn, the symbol values resolved from the
/// concrete shapes seen so far, and scratch storage slots.
pub struct TurnState {
    /// Concrete values of the symbols resolved during the current turn.
    /// Cleared by `SimpleState::reset_turn`.
    pub resolved_symbols: SymbolValues,
    /// Index of the symbol-scope scenario guessed from `resolved_symbols`, once one
    /// could be guessed.
    pub scenario: Option<usize>,
    /// Cached matrix-multiplication scratch space; not read or written in this file.
    /// NOTE(review): presumably filled lazily by the linalg kernels — confirm.
    pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
    /// Type-keyed storage for additional scratch data; not used in this file.
    pub scratch_extensions: anymap3::Map,
    /// Output values of each node, indexed by node id. Empty until a turn is prepared,
    /// then sized to the number of nodes of the model.
    pub values: Vec<Option<TVec<TValue>>>,
}
23
24impl Default for TurnState {
25    fn default() -> Self {
26        TurnState {
27            resolved_symbols: SymbolValues::default(),
28            scenario: None,
29            cached_mmm_scratch_space: None.into(),
30            scratch_extensions: anymap3::Map::new(),
31            values: vec![],
32        }
33    }
34}
35
36impl Clone for TurnState {
37    fn clone(&self) -> Self {
38        TurnState {
39            resolved_symbols: self.resolved_symbols.clone(),
40            scenario: self.scenario,
41            cached_mmm_scratch_space: None.into(),
42            scratch_extensions: anymap3::Map::new(),
43            values: vec![],
44        }
45    }
46}
47
/// Hook invoked around each execution of a plan.
///
/// `SimpleState` calls `before_plan_eval` once the turn values are prepared and before
/// the first node is evaluated, then `after_plan_eval` once every node of the plan has
/// been evaluated. An error returned by either hook aborts the execution.
pub trait SessionStateHandler: Send + Sync + Debug {
    /// Called before the first node of the plan is evaluated.
    fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
    /// Called after the last node of the plan has been evaluated.
    fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
}
52
53impl Debug for TurnState {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        write!(f, "SessionState({:?})", self.resolved_symbols)
56    }
57}
58
/// An execution plan for a model: the node evaluation order plus the bookkeeping
/// needed to release intermediate values as early as possible.
///
/// Execute it through a [`SimpleState`] obtained with `spawn`.
#[derive(Debug, Clone)]
pub struct SimplePlan<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The model being executed.
    pub(crate) model: Arc<Graph<F, O>>,
    /// Outlets whose values are returned by a run.
    outputs: Vec<OutletId>,
    /// Node ids in evaluation order. `Const` nodes are excluded: their values are
    /// pre-loaded when a turn is prepared.
    order: Vec<usize>,
    /// For each step of `order`, the node ids whose values can be dropped once that
    /// step's inputs have been gathered.
    flush_lists: Vec<TVec<usize>>,
    /// Whether some node output has a symbolic dimension. In that case symbols are
    /// also resolved from the concrete output shapes during execution.
    has_unresolved_symbols: bool,
    /// Optional executor that each plan execution is scoped to.
    executor: Option<Executor>,
    /// Optional hook called around each plan execution.
    session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
}
73
74impl<F, O> SimplePlan<F, O>
75where
76    F: Fact + Clone + 'static,
77    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
78{
79    /// This contructor returns a plan that will compute all the model default outputs in one pass.
80    pub fn new(model: impl Into<Arc<Graph<F, O>>>) -> TractResult<Arc<SimplePlan<F, O>>> {
81        let model = model.into();
82        Self::build(model, &RunOptions::default()).map(Arc::new)
83    }
84
85    /// This contructor returns a plan that will compute all the model default outputs in one pass.
86    pub fn new_with_options(
87        model: impl Into<Arc<Graph<F, O>>>,
88        options: &RunOptions,
89    ) -> TractResult<Arc<SimplePlan<F, O>>> {
90        let model = model.into();
91        Self::build(model, options).map(Arc::new)
92    }
93
94    /// This contructor returns a plan that will compute the specified output.
95    #[deprecated]
96    pub fn new_for_output(
97        model: Graph<F, O>,
98        output: OutletId,
99    ) -> TractResult<Arc<SimplePlan<F, O>>> {
100        #[allow(deprecated)]
101        Self::build_with_outputs_and_deps(model, &[output], &[], &RunOptions::default())
102            .map(Arc::new)
103    }
104
105    /// This contructor returns a plan that will compute all specified outputs in one pass.
106    #[deprecated]
107    pub fn new_for_outputs(
108        model: impl Into<Arc<Graph<F, O>>>,
109        outputs: &[OutletId],
110    ) -> TractResult<Arc<SimplePlan<F, O>>> {
111        #[allow(deprecated)]
112        Self::build_with_outputs_and_deps(model, outputs, &[], &RunOptions::default()).map(Arc::new)
113    }
114
115    pub fn with_session_handler<H: SessionStateHandler + 'static>(
116        mut self,
117        session_handler: H,
118    ) -> Self {
119        self.session_handler = Some(Arc::new(session_handler));
120        self
121    }
122
123    #[deprecated]
124    pub fn new_for_outputs_and_deps(
125        model: impl Into<Arc<Graph<F, O>>>,
126        outputs: &[OutletId],
127        deps: &[(usize, usize)],
128    ) -> TractResult<Arc<SimplePlan<F, O>>> {
129        #[allow(deprecated)]
130        Self::build_with_outputs_and_deps(model, outputs, deps, &RunOptions::default())
131            .map(Arc::new)
132    }
133
134    pub fn build(
135        model: impl Into<Arc<Graph<F, O>>>,
136        options: &RunOptions,
137    ) -> TractResult<SimplePlan<F, O>> {
138        let model = model.into();
139        let outputs = model.outputs.clone();
140        #[allow(deprecated)]
141        Self::build_with_outputs_and_deps(model, &outputs, &[], options)
142    }
143
144    #[deprecated]
145    pub fn build_with_outputs_and_deps(
146        model: impl Into<Arc<Graph<F, O>>>,
147        outputs: &[OutletId],
148        deps: &[(usize, usize)],
149        options: &RunOptions,
150    ) -> TractResult<SimplePlan<F, O>> {
151        let model = model.into();
152        let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
153        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
154        let mut order = if options.skip_order_opt_ram {
155            eval_order_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
156        } else {
157            eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
158        };
159        order.retain(|node| !model.node(*node).op_is::<Const>());
160        let flush_lists = build_flush_list(&*model, &order, outputs, |n| !n.op_is::<Const>());
161
162        #[allow(clippy::mutable_key_type)]
163        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
164        for node in &model.nodes {
165            for output in &node.outputs {
166                if let Ok(fact) = output.fact.to_typed_fact() {
167                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
168                }
169            }
170        }
171        Ok(SimplePlan {
172            model,
173            order,
174            flush_lists,
175            outputs: outputs.to_vec(),
176            has_unresolved_symbols: !symbols.is_empty(),
177            executor: options.executor.clone(),
178            session_handler: None,
179        })
180    }
181
182    pub fn order_without_consts(&self) -> &[usize] {
183        &self.order
184    }
185
186    pub fn run(self: &Arc<Self>, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
187        let mut state = self.spawn()?;
188        state.run(inputs)
189    }
190
191    pub fn model(&self) -> &Graph<F, O> {
192        self.model.borrow()
193    }
194
195    pub fn spawn(self: &Arc<Self>) -> TractResult<SimpleState<F, O>> {
196        SimpleState::new(self)
197    }
198}
199
/// A live execution of a [`SimplePlan`]: the inner state of each op plus the values
/// and symbols of the current turn.
///
/// Cloning goes through `TurnState`'s `Clone`, which does not copy node values or
/// scratch space.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The plan being executed.
    pub(crate) plan: Arc<SimplePlan<F, O>>,
    /// Inner state of each node, indexed by node id. `None` for stateless ops, or
    /// when the op provided no state.
    pub op_states: Vec<Option<Box<dyn OpState>>>,
    /// Values, resolved symbols and scratch space of the current turn.
    pub turn_state: TurnState,
}
210
211impl<F, O> SimpleState<F, O>
212where
213    F: Fact + Clone + 'static,
214    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
215{
216    pub fn new(plan: &Arc<SimplePlan<F, O>>) -> TractResult<SimpleState<F, O>> {
217        let plan = Arc::clone(plan);
218        let turn = TurnState::default();
219        let model = plan.model();
220        let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
221        let mut state = SimpleState { plan, op_states: states, turn_state: turn };
222        state.reset_op_states()?;
223        Ok(state)
224    }
225
226    pub fn new_from_inputs(
227        plan: &Arc<SimplePlan<F, O>>,
228        inputs: TVec<TValue>,
229    ) -> TractResult<SimpleState<F, O>> {
230        let mut state = SimpleState::new(plan)?;
231        state.set_inputs(inputs)?;
232        state.resolve_symbols_with_states()?;
233
234        Ok(state)
235    }
236
237    fn ready_turn(&mut self) {
238        if self.turn_state.values.len() == 0 {
239            self.turn_state.values = vec![None; self.plan.model.nodes().len()];
240            for node in &self.plan.model.nodes {
241                if let Some(k) = node.op_as::<Const>() {
242                    self.turn_state.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
243                }
244            }
245        }
246    }
247    /// Reset wires state.
248    pub fn reset_turn(&mut self) -> TractResult<()> {
249        for node in &self.plan.order {
250            self.turn_state.values[*node] = None;
251        }
252        self.turn_state.resolved_symbols = SymbolValues::default();
253        Ok(())
254    }
255
256    /// Reset op inner state.
257    fn reset_op_states(&mut self) -> TractResult<()> {
258        let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
259        for (ix, n) in plan.model.nodes.iter().enumerate() {
260            states[ix] = if n.op().is_stateless() { None } else { n.op().state(turn_state, ix)? };
261        }
262        Ok(())
263    }
264
265    pub fn init_states(&mut self, state_init_tensors: &[TValue]) -> TractResult<()> {
266        let states_to_init = self
267            .op_states
268            .iter_mut()
269            .filter_map(Option::as_mut)
270            .filter(|s| s.init_tensor_fact().is_some())
271            .collect_vec();
272        ensure!(
273            states_to_init.len() == state_init_tensors.len(),
274            "There are {} op to init but got {} tensors",
275            states_to_init.len(),
276            state_init_tensors.len()
277        );
278        let mut iterator = state_init_tensors.iter().cloned();
279        for state in states_to_init {
280            state.load_from(&mut self.turn_state, &mut iterator)?;
281        }
282        Ok(())
283    }
284
285    fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
286        for state in self
287            .op_states
288            .iter_mut()
289            .filter_map(Option::as_mut)
290            .filter(|s| s.init_tensor_fact().is_some())
291        {
292            state.resolve_symbols(&mut self.turn_state)?;
293        }
294        Ok(())
295    }
296
297    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
298        self.run_plan_with_eval(inputs, self::eval)
299    }
300
301    pub fn exec(&mut self) -> TractResult<()> {
302        self.exec_plan_with_eval(self::eval)
303    }
304
305    pub fn run_plan_with_eval<Eval, E>(
306        &mut self,
307        inputs: TVec<TValue>,
308        eval: Eval,
309    ) -> TractResult<TVec<TValue>>
310    where
311        Eval: for<'a, 'b, 'c> FnMut(
312            &'a mut TurnState,
313            Option<&'b mut (dyn OpState + 'static)>,
314            &'c Node<F, O>,
315            TVec<TValue>,
316        ) -> Result<TVec<TValue>, E>,
317        E: Into<anyhow::Error> + Send + Sync + 'static,
318    {
319        self.set_inputs(inputs)?;
320        self.resolve_symbols_with_states()?;
321        self.exec_plan_with_eval(eval)?;
322        let outputs = self.outputs()?;
323        self.reset_turn()?;
324        Ok(outputs)
325    }
326
327    pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
328    where
329        Eval: for<'a, 'b, 'c> FnMut(
330            &'a mut TurnState,
331            Option<&'b mut (dyn OpState + 'static)>,
332            &'c Node<F, O>,
333            TVec<TValue>,
334        ) -> Result<TVec<TValue>, E>,
335        E: Into<anyhow::Error> + Send + Sync + 'static,
336    {
337        if let Some(executor) = self.plan().executor.as_ref() {
338            tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
339                self.do_exec_plan_with_eval(eval)
340            })
341        } else {
342            self.do_exec_plan_with_eval(eval)
343        }
344    }
345
346    fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
347    where
348        Eval: for<'a, 'b, 'c> FnMut(
349            &'a mut TurnState,
350            Option<&'b mut (dyn OpState + 'static)>,
351            &'c Node<F, O>,
352            TVec<TValue>,
353        ) -> Result<TVec<TValue>, E>,
354        E: Into<anyhow::Error> + Send + Sync + 'static,
355    {
356        {
357            self.ready_turn();
358            self.plan
359                .session_handler
360                .as_ref()
361                .map(|it| it.before_plan_eval(&mut self.turn_state))
362                .transpose()?;
363
364            for (step, n) in self.plan.order.iter().enumerate() {
365                let node = self.plan.model.node(*n);
366                trace!("Running step {step}, node {node}");
367                let mut inputs: TVec<TValue> = tvec![];
368                for i in &node.inputs {
369                    trace!("  use input {i:?}");
370                    let prec_node = self.plan.model.node(i.node);
371                    let prec = self.turn_state.values[i.node].as_ref().ok_or_else(|| {
372                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
373                    })?;
374                    inputs.push(prec[i.slot].clone())
375                }
376
377                for flush in &self.plan.flush_lists[step] {
378                    trace!("  Ran {} can now flush {}", node, self.plan.model.node(*flush));
379                    self.turn_state.values[*flush] = None;
380                }
381
382                if cfg!(debug_assertions) {
383                    let facts = self.plan.model.node_input_facts(node.id)?;
384                    if facts.len() != inputs.len() {
385                        bail!(
386                            "Evaluating {}: expected {} inputs, got {}",
387                            node,
388                            facts.len(),
389                            inputs.len()
390                        );
391                    }
392                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
393                        if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
394                            bail!(
395                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
396                                node,
397                                ix,
398                                f,
399                                v
400                            );
401                        }
402                    }
403                }
404
405                let vs = eval(
406                    &mut self.turn_state,
407                    self.op_states[node.id].as_deref_mut(),
408                    node,
409                    inputs,
410                )
411                .map_err(|e| e.into())?;
412
413                if self.plan.has_unresolved_symbols {
414                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
415                        if let Ok(f) = o.fact.to_typed_fact() {
416                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
417                                Self::resolve(
418                                    &mut self.turn_state,
419                                    dim_abstract,
420                                    *dim_concrete as i64,
421                                )?;
422                            }
423                        }
424                    }
425                }
426                if cfg!(debug_assertions) {
427                    let facts = self.plan.model.node_output_facts(node.id)?;
428                    if facts.len() != vs.len() {
429                        bail!(
430                            "Evaluating {}: expected {} outputs, got {}",
431                            node,
432                            facts.len(),
433                            vs.len()
434                        );
435                    }
436                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
437                        if node.outputs[ix].successors.len() == 0 {
438                            continue;
439                        }
440                        if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
441                            bail!(
442                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
443                                node,
444                                ix,
445                                f,
446                                v
447                            );
448                        }
449                    }
450                }
451
452                self.turn_state.values[node.id] = Some(vs);
453            }
454            self.plan
455                .session_handler
456                .as_ref()
457                .map(|it| it.after_plan_eval(&mut self.turn_state))
458                .transpose()?;
459        }
460        Ok(())
461    }
462
463    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
464        ensure!(
465            inputs.len() == self.model().inputs.len(),
466            "Wrong number of inputs for model. Expected {} got {}",
467            self.model().inputs.len(),
468            inputs.len()
469        );
470
471        for (ix, t) in inputs.into_iter().enumerate() {
472            self.set_input(ix, t)?
473        }
474        Ok(())
475    }
476
477    fn resolve(state: &mut TurnState, expression: &TDim, provided: i64) -> TractResult<()> {
478        let expected = expression.eval(&state.resolved_symbols);
479        if let Ok(x) = expected.to_i64() {
480            if x != provided {
481                bail!(
482                    "Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})"
483                )
484            }
485        }
486        if expected.symbols().len() == 1 {
487            let sym = expected.symbols().into_iter().next().unwrap();
488            if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
489                debug!("Determined symbol {sym}={v}");
490                state.resolved_symbols.set(&sym, v.to_i64().unwrap());
491            }
492            if state.scenario.is_none() {
493                let scope = sym
494                    .scope()
495                    .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
496                state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
497            }
498        }
499        Ok(())
500    }
501
502    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
503        let outlet: OutletId = *self
504            .model()
505            .input_outlets()?
506            .get(input)
507            .with_context(|| format!("Invalid input id for model ({input})."))?;
508        if let Ok(fact) = self.plan.model.outlet_fact(outlet)?.to_typed_fact() {
509            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
510                Self::resolve(&mut self.turn_state, expected, *provided as i64)?;
511            }
512        }
513        let fact = self.plan.model.outlet_fact(outlet)?;
514        ensure!(
515            fact.matches(&t, Some(&self.turn_state.resolved_symbols))
516                .with_context(|| format!("Setting input {input}"))?,
517            "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
518        );
519        self.ready_turn();
520        self.turn_state.values[outlet.node] = Some(tvec!(t));
521        Ok(())
522    }
523
524    pub fn output(&self, id: usize) -> TractResult<&TValue> {
525        let outlet = self.model().output_outlets()?.get(id).with_context(|| {
526            format!(
527                "Required output {}, only have {}",
528                id,
529                self.model().output_outlets().unwrap().len()
530            )
531        })?;
532        let value: &TValue = self
533            .turn_state
534            .values
535            .get(outlet.node)
536            .context("node id for output beyond node values array")?
537            .as_ref()
538            .context("node is not an output")?
539            .get(outlet.slot)
540            .context("slot id too high")?;
541        Ok(value)
542    }
543
544    pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
545        let &mut SimpleState { ref plan, ref mut turn_state, .. } = self;
546        let mut v = tvec![];
547        for o in plan.outputs.iter() {
548            let vs = turn_state.values[o.node].as_mut().ok_or_else(|| {
549                format_err!("Outputs of {:?} are not computed", &plan.model.nodes()[o.node])
550            })?;
551            v.push(vs[o.slot].clone())
552        }
553        Ok(v)
554    }
555
556    pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
557        self.turn_state.values[id] = Some(values);
558        Ok(())
559    }
560
561    pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
562        self.set_values(id, tvec!(value))
563    }
564
565    pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
566        let SimpleState { plan, turn_state, .. } = self;
567        let nodes = plan.model.nodes();
568        let node = &nodes[node];
569        let mut inputs: TVec<TValue> = tvec![];
570        for i in &node.inputs {
571            let prec_node = &nodes[i.node];
572            let prec = turn_state.values[i.node].as_ref().ok_or_else(|| {
573                format_err!("Computing {}, precursor {} not done.", node, prec_node)
574            })?;
575            inputs.push(prec[i.slot].clone())
576        }
577        Ok(inputs)
578    }
579
580    pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
581        let inputs = self.prepare_inputs(node)?;
582        self.compute_one_with_inputs(node, inputs)
583    }
584
585    pub fn compute_one_with_inputs(
586        &mut self,
587        node: usize,
588        inputs: TVec<TValue>,
589    ) -> TractResult<()> {
590        let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
591        let nodes = plan.model.nodes();
592        let node = &nodes[node];
593        let vs = eval(turn_state, states[node.id].as_deref_mut(), node, inputs)?;
594        turn_state.values[node.id] = Some(vs);
595        Ok(())
596    }
597
598    pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
599        let values = {
600            #[allow(clippy::needless_collect)] // clippy bug ?
601            let precs: Vec<usize> =
602                self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
603            for i in precs.into_iter() {
604                if self.turn_state.values[i].is_none() {
605                    let _ = self.compute_recursively(i)?;
606                }
607            }
608            let mut inputs: TVec<TValue> = tvec![];
609            {
610                let node = &self.model().nodes()[node];
611                for i in &node.inputs {
612                    inputs.push(self.turn_state.values[i.node].as_ref().unwrap()[i.slot].clone())
613                }
614            }
615            let &mut Self {
616                op_states: ref mut states,
617                turn_state: ref mut session_state,
618                ref plan,
619                ..
620            } = self;
621            eval(session_state, states[node].as_deref_mut(), &plan.model().nodes[node], inputs)?
622        };
623        self.turn_state.values[node] = Some(values);
624        Ok(self.turn_state.values[node].as_ref().unwrap())
625    }
626
627    pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
628        let id = self.model().node_by_name(name)?.id;
629        Self::take(self, id)
630    }
631
632    pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
633        Ok(self.turn_state.values[id]
634            .take()
635            .ok_or_else(|| format_err!("Node is not computed"))?
636            .into_iter()
637            .map(|v| v.into_tensor())
638            .collect())
639    }
640
641    pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
642        &self.plan
643    }
644
645    pub fn model(&self) -> &Graph<F, O> {
646        &self.plan.model
647    }
648
649    pub fn freeze(&self) -> FrozenSimpleState<F, O> {
650        FrozenSimpleState {
651            plan: self.plan.clone(),
652            resolved_symbols: self.turn_state.resolved_symbols.clone(),
653            scenario: self.turn_state.scenario,
654            states: self.op_states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
655            values: self
656                .turn_state
657                .values
658                .iter()
659                .enumerate()
660                .map(|(ix, t)| {
661                    if self.model().nodes[ix].op_is::<Const>() {
662                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
663                    } else {
664                        None
665                    }
666                })
667                .collect(),
668        }
669    }
670}
671
672pub fn eval<F, O>(
673    session_state: &mut TurnState,
674    mut state: Option<&mut (dyn OpState + 'static)>,
675    node: &Node<F, O>,
676    input: TVec<TValue>,
677) -> TractResult<TVec<TValue>>
678where
679    F: Fact + Clone + 'static,
680    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
681{
682    // eprint!("{node} {input:?}");
683    #[allow(clippy::let_and_return)]
684    let r = match state {
685        Some(ref mut state) => state.eval(session_state, node.op(), input),
686        None => node.op().eval_with_session(node.id, session_state, input),
687    }
688    .with_context(|| format!("Evaluating {node}"));
689    // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
690    r
691}
692
/// A snapshot of a [`SimpleState`], produced by `SimpleState::freeze` and turned back
/// into a live state by `unfreeze`.
#[derive(Clone, Debug)]
pub struct FrozenSimpleState<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The plan the frozen state was executing.
    plan: Arc<SimplePlan<F, O>>,
    /// Symbol values resolved at freeze time.
    pub resolved_symbols: SymbolValues,
    /// Scenario guessed at freeze time, if any.
    pub scenario: Option<usize>,
    /// Frozen op states, indexed by node id. `None` where the live state had none.
    pub states: Vec<Option<Box<dyn FrozenOpState>>>,
    /// Node values, indexed by node id. `freeze` only keeps the values of `Const`
    /// nodes; every other entry is `None`.
    pub values: Vec<Option<TVec<Tensor>>>,
}
705
706impl<F, O> FrozenSimpleState<F, O>
707where
708    F: Fact + Clone + 'static,
709    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
710{
711    pub fn unfreeze(&self) -> SimpleState<F, O> {
712        SimpleState {
713            plan: self.plan.clone(),
714            turn_state: TurnState {
715                resolved_symbols: self.resolved_symbols.clone(),
716                scenario: self.scenario,
717                cached_mmm_scratch_space: None.into(),
718                scratch_extensions: anymap3::Map::new(),
719                values: self
720                    .values
721                    .iter()
722                    .map(|t| {
723                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect())
724                    })
725                    .collect(),
726            },
727            op_states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
728        }
729    }
730}
731
#[cfg(test)]
mod test {
    use super::*;

    // Compile-time checks: a call only builds if `T` implements the auto trait.
    fn assert_send<T: Send>() {}
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>();
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>();
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan>();
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan>();
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState>();
    }
}