Skip to main content

tract_core/
plan.rs

1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4
5use multithread::Executor;
6
7use crate::internal::*;
8use crate::model::{Fact, Graph, OutletId};
9use crate::ops::FrozenOpState;
10use crate::ops::konst::Const;
11use crate::runtime::RunOptions;
12
13use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
14
15pub struct TurnState {
16    pub resolved_symbols: SymbolValues,
17    pub scenario: Option<usize>,
18    pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
19    pub scratch_extensions: anymap3::Map,
20    pub values: Vec<Option<TVec<TValue>>>,
21}
22
23impl Default for TurnState {
24    fn default() -> Self {
25        TurnState {
26            resolved_symbols: SymbolValues::default(),
27            scenario: None,
28            cached_mmm_scratch_space: None.into(),
29            scratch_extensions: anymap3::Map::new(),
30            values: vec![],
31        }
32    }
33}
34
35impl Clone for TurnState {
36    fn clone(&self) -> Self {
37        TurnState {
38            resolved_symbols: self.resolved_symbols.clone(),
39            scenario: self.scenario,
40            cached_mmm_scratch_space: None.into(),
41            scratch_extensions: anymap3::Map::new(),
42            values: vec![],
43        }
44    }
45}
46
47pub trait SessionStateHandler: Send + Sync + Debug {
48    fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
49    fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
50}
51
52impl Debug for TurnState {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "SessionState({:?})", self.resolved_symbols)
55    }
56}
57
58#[derive(Debug, Clone)]
59pub struct SimplePlan<F, O>
60where
61    F: Fact + Clone + 'static,
62    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
63{
64    pub(crate) model: Arc<Graph<F, O>>,
65    outputs: Vec<OutletId>,
66    order: Vec<usize>,
67    flush_lists: Vec<TVec<usize>>,
68    has_unresolved_symbols: bool,
69    executor: Option<Executor>,
70    session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
71}
72
73impl<F, O> SimplePlan<F, O>
74where
75    F: Fact + Clone + 'static,
76    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
77{
78    /// This contructor returns a plan that will compute all the model default outputs in one pass.
79    pub fn new(model: impl Into<Arc<Graph<F, O>>>) -> TractResult<Arc<SimplePlan<F, O>>> {
80        let model = model.into();
81        Self::build(model, &RunOptions::default()).map(Arc::new)
82    }
83
84    /// This contructor returns a plan that will compute all the model default outputs in one pass.
85    pub fn new_with_options(
86        model: impl Into<Arc<Graph<F, O>>>,
87        options: &RunOptions,
88    ) -> TractResult<Arc<SimplePlan<F, O>>> {
89        let model = model.into();
90        Self::build(model, options).map(Arc::new)
91    }
92
93    /// This contructor returns a plan that will compute the specified output.
94    #[deprecated]
95    pub fn new_for_output(
96        model: Graph<F, O>,
97        output: OutletId,
98    ) -> TractResult<Arc<SimplePlan<F, O>>> {
99        #[allow(deprecated)]
100        Self::build_with_outputs_and_deps(model, &[output], &[], &RunOptions::default())
101            .map(Arc::new)
102    }
103
104    /// This contructor returns a plan that will compute all specified outputs in one pass.
105    #[deprecated]
106    pub fn new_for_outputs(
107        model: impl Into<Arc<Graph<F, O>>>,
108        outputs: &[OutletId],
109    ) -> TractResult<Arc<SimplePlan<F, O>>> {
110        #[allow(deprecated)]
111        Self::build_with_outputs_and_deps(model, outputs, &[], &RunOptions::default()).map(Arc::new)
112    }
113
114    pub fn with_session_handler<H: SessionStateHandler + 'static>(
115        mut self,
116        session_handler: H,
117    ) -> Self {
118        self.session_handler = Some(Arc::new(session_handler));
119        self
120    }
121
122    #[deprecated]
123    pub fn new_for_outputs_and_deps(
124        model: impl Into<Arc<Graph<F, O>>>,
125        outputs: &[OutletId],
126        deps: &[(usize, usize)],
127    ) -> TractResult<Arc<SimplePlan<F, O>>> {
128        #[allow(deprecated)]
129        Self::build_with_outputs_and_deps(model, outputs, deps, &RunOptions::default())
130            .map(Arc::new)
131    }
132
133    pub fn build(
134        model: impl Into<Arc<Graph<F, O>>>,
135        options: &RunOptions,
136    ) -> TractResult<SimplePlan<F, O>> {
137        let model = model.into();
138        let outputs = model.outputs.clone();
139        #[allow(deprecated)]
140        Self::build_with_outputs_and_deps(model, &outputs, &[], options)
141    }
142
143    #[deprecated]
144    pub fn build_with_outputs_and_deps(
145        model: impl Into<Arc<Graph<F, O>>>,
146        outputs: &[OutletId],
147        deps: &[(usize, usize)],
148        options: &RunOptions,
149    ) -> TractResult<SimplePlan<F, O>> {
150        let model = model.into();
151        let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
152        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
153        let mut order = if options.skip_order_opt_ram {
154            eval_order_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
155        } else {
156            eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
157        };
158        order.retain(|node| !model.node(*node).op_is::<Const>());
159        let flush_lists = build_flush_list(&*model, &order, outputs, |n| !n.op_is::<Const>());
160
161        #[allow(clippy::mutable_key_type)]
162        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
163        for node in &model.nodes {
164            for output in &node.outputs {
165                if let Ok(fact) = output.fact.to_typed_fact() {
166                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
167                }
168            }
169        }
170        Ok(SimplePlan {
171            model,
172            order,
173            flush_lists,
174            outputs: outputs.to_vec(),
175            has_unresolved_symbols: !symbols.is_empty(),
176            executor: options.executor.clone(),
177            session_handler: None,
178        })
179    }
180
181    pub fn order_without_consts(&self) -> &[usize] {
182        &self.order
183    }
184
185    pub fn run(self: &Arc<Self>, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
186        let mut state = self.spawn()?;
187        state.run(inputs)
188    }
189
190    pub fn model(&self) -> &Graph<F, O> {
191        self.model.borrow()
192    }
193
194    pub fn spawn(self: &Arc<Self>) -> TractResult<SimpleState<F, O>> {
195        SimpleState::new(self)
196    }
197}
198
199#[derive(Clone, Debug)]
200pub struct SimpleState<F, O>
201where
202    F: Fact + Clone + 'static,
203    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
204{
205    pub(crate) plan: Arc<SimplePlan<F, O>>,
206    pub op_states: Vec<Option<Box<dyn OpState>>>,
207    pub turn_state: TurnState,
208}
209
210impl<F, O> SimpleState<F, O>
211where
212    F: Fact + Clone + 'static,
213    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
214{
215    pub fn new(plan: &Arc<SimplePlan<F, O>>) -> TractResult<SimpleState<F, O>> {
216        let plan = Arc::clone(plan);
217        let turn = TurnState::default();
218        let model = plan.model();
219        let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
220        let mut state = SimpleState { plan, op_states: states, turn_state: turn };
221        state.reset_op_states()?;
222        Ok(state)
223    }
224
225    pub fn new_from_inputs(
226        plan: &Arc<SimplePlan<F, O>>,
227        inputs: TVec<TValue>,
228    ) -> TractResult<SimpleState<F, O>> {
229        let mut state = SimpleState::new(plan)?;
230        state.set_inputs(inputs)?;
231        state.resolve_symbols_with_states()?;
232
233        Ok(state)
234    }
235
236    fn ready_turn(&mut self) {
237        if self.turn_state.values.len() == 0 {
238            self.turn_state.values = vec![None; self.plan.model.nodes().len()];
239            for node in &self.plan.model.nodes {
240                if let Some(k) = node.op_as::<Const>() {
241                    self.turn_state.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
242                }
243            }
244        }
245    }
246    /// Reset wires state.
247    pub fn reset_turn(&mut self) -> TractResult<()> {
248        for node in &self.plan.order {
249            self.turn_state.values[*node] = None;
250        }
251        self.turn_state.resolved_symbols = SymbolValues::default();
252        Ok(())
253    }
254
255    /// Reset op inner state.
256    fn reset_op_states(&mut self) -> TractResult<()> {
257        let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
258        for (ix, n) in plan.model.nodes.iter().enumerate() {
259            states[ix] = if n.op().is_stateless() { None } else { n.op().state(turn_state, ix)? };
260        }
261        Ok(())
262    }
263
264    fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
265        for state in self
266            .op_states
267            .iter_mut()
268            .filter_map(Option::as_mut)
269            .filter(|s| s.init_tensor_fact().is_some())
270        {
271            state.resolve_symbols(&mut self.turn_state)?;
272        }
273        Ok(())
274    }
275
276    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
277        self.run_plan_with_eval(inputs, self::eval)
278    }
279
280    pub fn exec(&mut self) -> TractResult<()> {
281        self.exec_plan_with_eval(self::eval)
282    }
283
284    pub fn run_plan_with_eval<Eval, E>(
285        &mut self,
286        inputs: TVec<TValue>,
287        eval: Eval,
288    ) -> TractResult<TVec<TValue>>
289    where
290        Eval: for<'a, 'b, 'c> FnMut(
291            &'a mut TurnState,
292            Option<&'b mut (dyn OpState + 'static)>,
293            &'c Node<F, O>,
294            TVec<TValue>,
295        ) -> Result<TVec<TValue>, E>,
296        E: Into<anyhow::Error> + Send + Sync + 'static,
297    {
298        self.set_inputs(inputs)?;
299        self.resolve_symbols_with_states()?;
300        self.exec_plan_with_eval(eval)?;
301        let outputs = self.outputs()?;
302        self.reset_turn()?;
303        Ok(outputs)
304    }
305
306    pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
307    where
308        Eval: for<'a, 'b, 'c> FnMut(
309            &'a mut TurnState,
310            Option<&'b mut (dyn OpState + 'static)>,
311            &'c Node<F, O>,
312            TVec<TValue>,
313        ) -> Result<TVec<TValue>, E>,
314        E: Into<anyhow::Error> + Send + Sync + 'static,
315    {
316        if let Some(executor) = self.plan().executor.as_ref() {
317            tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
318                self.do_exec_plan_with_eval(eval)
319            })
320        } else {
321            self.do_exec_plan_with_eval(eval)
322        }
323    }
324
325    fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
326    where
327        Eval: for<'a, 'b, 'c> FnMut(
328            &'a mut TurnState,
329            Option<&'b mut (dyn OpState + 'static)>,
330            &'c Node<F, O>,
331            TVec<TValue>,
332        ) -> Result<TVec<TValue>, E>,
333        E: Into<anyhow::Error> + Send + Sync + 'static,
334    {
335        {
336            self.ready_turn();
337            self.plan
338                .session_handler
339                .as_ref()
340                .map(|it| it.before_plan_eval(&mut self.turn_state))
341                .transpose()?;
342
343            for (step, n) in self.plan.order.iter().enumerate() {
344                let node = self.plan.model.node(*n);
345                trace!("Running step {step}, node {node}");
346                let mut inputs: TVec<TValue> = tvec![];
347                for i in &node.inputs {
348                    trace!("  use input {i:?}");
349                    let prec_node = self.plan.model.node(i.node);
350                    let prec = self.turn_state.values[i.node].as_ref().ok_or_else(|| {
351                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
352                    })?;
353                    inputs.push(prec[i.slot].clone())
354                }
355
356                for flush in &self.plan.flush_lists[step] {
357                    trace!("  Ran {} can now flush {}", node, self.plan.model.node(*flush));
358                    self.turn_state.values[*flush] = None;
359                }
360
361                if cfg!(debug_assertions) {
362                    let facts = self.plan.model.node_input_facts(node.id)?;
363                    if facts.len() != inputs.len() {
364                        bail!(
365                            "Evaluating {}: expected {} inputs, got {}",
366                            node,
367                            facts.len(),
368                            inputs.len()
369                        );
370                    }
371                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
372                        if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
373                            bail!(
374                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
375                                node,
376                                ix,
377                                f,
378                                v
379                            );
380                        }
381                    }
382                }
383
384                let vs = eval(
385                    &mut self.turn_state,
386                    self.op_states[node.id].as_deref_mut(),
387                    node,
388                    inputs,
389                )
390                .map_err(|e| e.into())?;
391
392                if self.plan.has_unresolved_symbols {
393                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
394                        if let Ok(f) = o.fact.to_typed_fact() {
395                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
396                                Self::resolve(
397                                    &mut self.turn_state,
398                                    dim_abstract,
399                                    *dim_concrete as i64,
400                                )?;
401                            }
402                        }
403                    }
404                }
405                if cfg!(debug_assertions) {
406                    let facts = self.plan.model.node_output_facts(node.id)?;
407                    if facts.len() != vs.len() {
408                        bail!(
409                            "Evaluating {}: expected {} outputs, got {}",
410                            node,
411                            facts.len(),
412                            vs.len()
413                        );
414                    }
415                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
416                        if node.outputs[ix].successors.len() == 0 {
417                            continue;
418                        }
419                        if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
420                            bail!(
421                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
422                                node,
423                                ix,
424                                f,
425                                v
426                            );
427                        }
428                    }
429                }
430
431                self.turn_state.values[node.id] = Some(vs);
432            }
433            self.plan
434                .session_handler
435                .as_ref()
436                .map(|it| it.after_plan_eval(&mut self.turn_state))
437                .transpose()?;
438        }
439        Ok(())
440    }
441
442    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
443        ensure!(
444            inputs.len() == self.model().inputs.len(),
445            "Wrong number of inputs for model. Expected {} got {}",
446            self.model().inputs.len(),
447            inputs.len()
448        );
449
450        for (ix, t) in inputs.into_iter().enumerate() {
451            self.set_input(ix, t)?
452        }
453        Ok(())
454    }
455
456    fn resolve(state: &mut TurnState, expression: &TDim, provided: i64) -> TractResult<()> {
457        let expected = expression.eval(&state.resolved_symbols);
458        if let Ok(x) = expected.to_i64()
459            && x != provided
460        {
461            bail!("Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})")
462        }
463        if expected.symbols().len() == 1 {
464            let sym = expected.symbols().into_iter().next().unwrap();
465            if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
466                debug!("Determined symbol {sym}={v}");
467                state.resolved_symbols.set(&sym, v.to_i64().unwrap());
468            }
469            if state.scenario.is_none() {
470                let scope = sym
471                    .scope()
472                    .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
473                state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
474            }
475        }
476        Ok(())
477    }
478
479    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
480        let outlet: OutletId = *self
481            .model()
482            .input_outlets()?
483            .get(input)
484            .with_context(|| format!("Invalid input id for model ({input})."))?;
485        if let Ok(fact) = self.plan.model.outlet_fact(outlet)?.to_typed_fact() {
486            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
487                Self::resolve(&mut self.turn_state, expected, *provided as i64)?;
488            }
489        }
490        let fact = self.plan.model.outlet_fact(outlet)?;
491        ensure!(
492            fact.matches(&t, Some(&self.turn_state.resolved_symbols))
493                .with_context(|| format!("Setting input {input}"))?,
494            "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
495        );
496        self.ready_turn();
497        self.turn_state.values[outlet.node] = Some(tvec!(t));
498        Ok(())
499    }
500
501    pub fn output(&self, id: usize) -> TractResult<&TValue> {
502        let outlet = self.model().output_outlets()?.get(id).with_context(|| {
503            format!(
504                "Required output {}, only have {}",
505                id,
506                self.model().output_outlets().unwrap().len()
507            )
508        })?;
509        let value: &TValue = self
510            .turn_state
511            .values
512            .get(outlet.node)
513            .context("node id for output beyond node values array")?
514            .as_ref()
515            .context("node is not an output")?
516            .get(outlet.slot)
517            .context("slot id too high")?;
518        Ok(value)
519    }
520
521    pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
522        let &mut SimpleState { ref plan, ref mut turn_state, .. } = self;
523        let mut v = tvec![];
524        for o in plan.outputs.iter() {
525            let vs = turn_state.values[o.node].as_mut().ok_or_else(|| {
526                format_err!("Outputs of {:?} are not computed", &plan.model.nodes()[o.node])
527            })?;
528            v.push(vs[o.slot].clone())
529        }
530        Ok(v)
531    }
532
533    pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
534        self.turn_state.values[id] = Some(values);
535        Ok(())
536    }
537
538    pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
539        self.set_values(id, tvec!(value))
540    }
541
542    pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
543        let SimpleState { plan, turn_state, .. } = self;
544        let nodes = plan.model.nodes();
545        let node = &nodes[node];
546        let mut inputs: TVec<TValue> = tvec![];
547        for i in &node.inputs {
548            let prec_node = &nodes[i.node];
549            let prec = turn_state.values[i.node].as_ref().ok_or_else(|| {
550                format_err!("Computing {}, precursor {} not done.", node, prec_node)
551            })?;
552            inputs.push(prec[i.slot].clone())
553        }
554        Ok(inputs)
555    }
556
557    pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
558        let inputs = self.prepare_inputs(node)?;
559        self.compute_one_with_inputs(node, inputs)
560    }
561
562    pub fn compute_one_with_inputs(
563        &mut self,
564        node: usize,
565        inputs: TVec<TValue>,
566    ) -> TractResult<()> {
567        let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
568        let nodes = plan.model.nodes();
569        let node = &nodes[node];
570        let vs = eval(turn_state, states[node.id].as_deref_mut(), node, inputs)?;
571        turn_state.values[node.id] = Some(vs);
572        Ok(())
573    }
574
575    pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
576        let values = {
577            #[allow(clippy::needless_collect)] // clippy bug ?
578            let precs: Vec<usize> =
579                self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
580            for i in precs.into_iter() {
581                if self.turn_state.values[i].is_none() {
582                    let _ = self.compute_recursively(i)?;
583                }
584            }
585            let mut inputs: TVec<TValue> = tvec![];
586            {
587                let node = &self.model().nodes()[node];
588                for i in &node.inputs {
589                    inputs.push(self.turn_state.values[i.node].as_ref().unwrap()[i.slot].clone())
590                }
591            }
592            let &mut Self {
593                op_states: ref mut states,
594                turn_state: ref mut session_state,
595                ref plan,
596                ..
597            } = self;
598            eval(session_state, states[node].as_deref_mut(), &plan.model().nodes[node], inputs)?
599        };
600        self.turn_state.values[node] = Some(values);
601        Ok(self.turn_state.values[node].as_ref().unwrap())
602    }
603
604    pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
605        let id = self.model().node_by_name(name)?.id;
606        Self::take(self, id)
607    }
608
609    pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
610        Ok(self.turn_state.values[id]
611            .take()
612            .ok_or_else(|| format_err!("Node is not computed"))?
613            .into_iter()
614            .map(|v| v.into_tensor())
615            .collect())
616    }
617
618    pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
619        &self.plan
620    }
621
622    pub fn model(&self) -> &Graph<F, O> {
623        &self.plan.model
624    }
625
626    pub fn freeze(&self) -> FrozenSimpleState<F, O> {
627        FrozenSimpleState {
628            plan: self.plan.clone(),
629            resolved_symbols: self.turn_state.resolved_symbols.clone(),
630            scenario: self.turn_state.scenario,
631            states: self.op_states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
632            values: self
633                .turn_state
634                .values
635                .iter()
636                .enumerate()
637                .map(|(ix, t)| {
638                    if self.model().nodes[ix].op_is::<Const>() {
639                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
640                    } else {
641                        None
642                    }
643                })
644                .collect(),
645        }
646    }
647
648    pub fn freeze_into(self) -> FrozenSimpleState<F, O> {
649        let plan = self.plan;
650        let model = &plan.model;
651        FrozenSimpleState {
652            resolved_symbols: self.turn_state.resolved_symbols,
653            scenario: self.turn_state.scenario,
654            states: self.op_states.into_iter().map(|s| s.map(|s| s.freeze_into())).collect(),
655            values: self
656                .turn_state
657                .values
658                .into_iter()
659                .enumerate()
660                .map(|(ix, t)| {
661                    if model.nodes[ix].op_is::<Const>() {
662                        t.map(|t| t.into_iter().map(|t| t.into_tensor()).collect())
663                    } else {
664                        None
665                    }
666                })
667                .collect(),
668            plan,
669        }
670    }
671}
672
673pub fn eval<F, O>(
674    session_state: &mut TurnState,
675    mut state: Option<&mut (dyn OpState + 'static)>,
676    node: &Node<F, O>,
677    input: TVec<TValue>,
678) -> TractResult<TVec<TValue>>
679where
680    F: Fact + Clone + 'static,
681    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
682{
683    // eprint!("{node} {input:?}");
684    #[allow(clippy::let_and_return)]
685    let r = match state {
686        Some(ref mut state) => state.eval(session_state, node.op(), input),
687        None => node.op().eval_with_session(node.id, session_state, input),
688    }
689    .with_context(|| format!("Evaluating {node}"));
690    // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
691    r
692}
693
694#[derive(Clone, Debug)]
695pub struct FrozenSimpleState<F, O>
696where
697    F: Fact + Clone + 'static,
698    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
699{
700    plan: Arc<SimplePlan<F, O>>,
701    pub resolved_symbols: SymbolValues,
702    pub scenario: Option<usize>,
703    pub states: Vec<Option<Box<dyn FrozenOpState>>>,
704    pub values: Vec<Option<TVec<Tensor>>>,
705}
706
707impl<F, O> FrozenSimpleState<F, O>
708where
709    F: Fact + Clone + 'static,
710    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
711{
712    pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
713        &self.plan
714    }
715
716    pub fn unfreeze(&self) -> SimpleState<F, O> {
717        SimpleState {
718            plan: self.plan.clone(),
719            turn_state: TurnState {
720                resolved_symbols: self.resolved_symbols.clone(),
721                scenario: self.scenario,
722                cached_mmm_scratch_space: None.into(),
723                scratch_extensions: anymap3::Map::new(),
724                values: self
725                    .values
726                    .iter()
727                    .map(|t| {
728                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect())
729                    })
730                    .collect(),
731            },
732            op_states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
733        }
734    }
735}
736
#[cfg(test)]
mod test {
    use super::*;

    /// Compiles only when `T: Send`.
    fn assert_send<T: Send>() {}

    /// Compiles only when `T: Sync`.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>()
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>()
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan>()
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan>()
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState>()
    }
}