Skip to main content

tract_core/
plan.rs

1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4
5use multithread::Executor;
6
7use crate::internal::*;
8use crate::model::{Fact, Graph, OutletId};
9use crate::ops::FrozenOpState;
10use crate::ops::konst::Const;
11use crate::runtime::RunOptions;
12
13use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
14
/// Per-turn evaluation context shared by all ops while a plan runs.
pub struct TurnState {
    /// Values of symbols resolved so far during this turn (from inputs and output shapes).
    pub resolved_symbols: SymbolValues,
    /// Scenario index, guessed from the resolved symbols' scope once symbols get resolved.
    pub scenario: Option<usize>,
    /// Matrix-multiplication scratch space cached across op evaluations.
    pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
    /// Type-indexed map for additional scratch data.
    pub scratch_extensions: anymap3::Map,
    /// Output values per node id (`None` = not computed yet, or flushed). Empty until the
    /// state allocates it.
    pub values: Vec<Option<TVec<TValue>>>,
}
22
23impl Default for TurnState {
24    fn default() -> Self {
25        TurnState {
26            resolved_symbols: SymbolValues::default(),
27            scenario: None,
28            cached_mmm_scratch_space: None.into(),
29            scratch_extensions: anymap3::Map::new(),
30            values: vec![],
31        }
32    }
33}
34
35impl Clone for TurnState {
36    fn clone(&self) -> Self {
37        TurnState {
38            resolved_symbols: self.resolved_symbols.clone(),
39            scenario: self.scenario,
40            cached_mmm_scratch_space: None.into(),
41            scratch_extensions: anymap3::Map::new(),
42            values: vec![],
43        }
44    }
45}
46
/// Hooks run around each evaluation of a [`SimplePlan`] that has a session handler attached.
///
/// Errors returned by either hook abort the evaluation and are propagated to the caller.
pub trait SessionStateHandler: Send + Sync + Debug {
    /// Called once the turn values are allocated, before the first node of the plan is evaluated.
    fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
    /// Called after the last node of the plan order has been evaluated.
    fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
}
51
52impl Debug for TurnState {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "SessionState({:?})", self.resolved_symbols)
55    }
56}
57
/// An immutable evaluation plan for a model: which nodes to run, in which order, and when
/// intermediate values can be dropped.
#[derive(Debug, Clone)]
pub struct SimplePlan<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The model being evaluated.
    pub(crate) model: Arc<Graph<F, O>>,
    /// Outlets whose values are returned by a run.
    outputs: Vec<OutletId>,
    /// Node ids in evaluation order, `Const` nodes excluded.
    order: Vec<usize>,
    /// For each step of `order`, the node ids whose values can be dropped at that step.
    flush_lists: Vec<TVec<usize>>,
    /// Whether some typed output fact in the model has a symbolic dimension; when set,
    /// evaluation resolves symbols from concrete output shapes.
    has_unresolved_symbols: bool,
    /// Optional executor used as the multithreading scope while evaluating.
    executor: Option<Executor>,
    /// Optional hooks run before and after each plan evaluation.
    session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
}
72
73impl<F, O> SimplePlan<F, O>
74where
75    F: Fact + Clone + 'static,
76    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
77{
78    /// This contructor returns a plan that will compute all the model default outputs in one pass.
79    pub fn new(model: impl Into<Arc<Graph<F, O>>>) -> TractResult<Arc<SimplePlan<F, O>>> {
80        let model = model.into();
81        Self::build(model, &RunOptions::default()).map(Arc::new)
82    }
83
84    /// This contructor returns a plan that will compute all the model default outputs in one pass.
85    pub fn new_with_options(
86        model: impl Into<Arc<Graph<F, O>>>,
87        options: &RunOptions,
88    ) -> TractResult<Arc<SimplePlan<F, O>>> {
89        let model = model.into();
90        Self::build(model, options).map(Arc::new)
91    }
92
93    /// This contructor returns a plan that will compute the specified output.
94    #[deprecated]
95    pub fn new_for_output(
96        model: Graph<F, O>,
97        output: OutletId,
98    ) -> TractResult<Arc<SimplePlan<F, O>>> {
99        #[allow(deprecated)]
100        Self::build_with_outputs_and_deps(model, &[output], &[], &RunOptions::default())
101            .map(Arc::new)
102    }
103
104    /// This contructor returns a plan that will compute all specified outputs in one pass.
105    #[deprecated]
106    pub fn new_for_outputs(
107        model: impl Into<Arc<Graph<F, O>>>,
108        outputs: &[OutletId],
109    ) -> TractResult<Arc<SimplePlan<F, O>>> {
110        #[allow(deprecated)]
111        Self::build_with_outputs_and_deps(model, outputs, &[], &RunOptions::default()).map(Arc::new)
112    }
113
114    pub fn with_session_handler<H: SessionStateHandler + 'static>(
115        mut self,
116        session_handler: H,
117    ) -> Self {
118        self.session_handler = Some(Arc::new(session_handler));
119        self
120    }
121
122    #[deprecated]
123    pub fn new_for_outputs_and_deps(
124        model: impl Into<Arc<Graph<F, O>>>,
125        outputs: &[OutletId],
126        deps: &[(usize, usize)],
127    ) -> TractResult<Arc<SimplePlan<F, O>>> {
128        #[allow(deprecated)]
129        Self::build_with_outputs_and_deps(model, outputs, deps, &RunOptions::default())
130            .map(Arc::new)
131    }
132
133    pub fn build(
134        model: impl Into<Arc<Graph<F, O>>>,
135        options: &RunOptions,
136    ) -> TractResult<SimplePlan<F, O>> {
137        let model = model.into();
138        let outputs = model.outputs.clone();
139        #[allow(deprecated)]
140        Self::build_with_outputs_and_deps(model, &outputs, &[], options)
141    }
142
143    #[deprecated]
144    pub fn build_with_outputs_and_deps(
145        model: impl Into<Arc<Graph<F, O>>>,
146        outputs: &[OutletId],
147        deps: &[(usize, usize)],
148        options: &RunOptions,
149    ) -> TractResult<SimplePlan<F, O>> {
150        let model = model.into();
151        let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
152        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
153        let mut order = if options.skip_order_opt_ram {
154            eval_order_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
155        } else {
156            eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
157        };
158        order.retain(|node| !model.node(*node).op_is::<Const>());
159        let flush_lists = build_flush_list(&*model, &order, outputs, |n| !n.op_is::<Const>());
160
161        #[allow(clippy::mutable_key_type)]
162        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
163        for node in &model.nodes {
164            for output in &node.outputs {
165                if let Ok(fact) = output.fact.to_typed_fact() {
166                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
167                }
168            }
169        }
170        Ok(SimplePlan {
171            model,
172            order,
173            flush_lists,
174            outputs: outputs.to_vec(),
175            has_unresolved_symbols: !symbols.is_empty(),
176            executor: options.executor.clone(),
177            session_handler: None,
178        })
179    }
180
181    pub fn order_without_consts(&self) -> &[usize] {
182        &self.order
183    }
184
185    pub fn run(self: &Arc<Self>, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
186        let mut state = self.spawn()?;
187        state.run(inputs)
188    }
189
190    pub fn model(&self) -> &Graph<F, O> {
191        self.model.borrow()
192    }
193
194    pub fn spawn(self: &Arc<Self>) -> TractResult<SimpleState<F, O>> {
195        SimpleState::new(self)
196    }
197}
198
/// Mutable execution state for a [`SimplePlan`]: per-node op states plus the per-turn values.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The plan this state executes.
    pub(crate) plan: Arc<SimplePlan<F, O>>,
    /// One entry per model node; `None` for stateless ops.
    pub op_states: Vec<Option<Box<dyn OpState>>>,
    /// Values and resolved symbols of the current turn.
    pub turn_state: TurnState,
}
209
210impl<F, O> SimpleState<F, O>
211where
212    F: Fact + Clone + 'static,
213    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
214{
215    pub fn new(plan: &Arc<SimplePlan<F, O>>) -> TractResult<SimpleState<F, O>> {
216        let plan = Arc::clone(plan);
217        let turn = TurnState::default();
218        let model = plan.model();
219        let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
220        let mut state = SimpleState { plan, op_states: states, turn_state: turn };
221        state.reset_op_states()?;
222        Ok(state)
223    }
224
225    pub fn new_from_inputs(
226        plan: &Arc<SimplePlan<F, O>>,
227        inputs: TVec<TValue>,
228    ) -> TractResult<SimpleState<F, O>> {
229        let mut state = SimpleState::new(plan)?;
230        state.set_inputs(inputs)?;
231        state.resolve_symbols_with_states()?;
232
233        Ok(state)
234    }
235
236    fn ready_turn(&mut self) {
237        if self.turn_state.values.len() == 0 {
238            self.turn_state.values = vec![None; self.plan.model.nodes().len()];
239            for node in &self.plan.model.nodes {
240                if let Some(k) = node.op_as::<Const>() {
241                    self.turn_state.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
242                }
243            }
244        }
245    }
246    /// Reset wires state.
247    pub fn reset_turn(&mut self) -> TractResult<()> {
248        for node in &self.plan.order {
249            self.turn_state.values[*node] = None;
250        }
251        self.turn_state.resolved_symbols = SymbolValues::default();
252        Ok(())
253    }
254
255    /// Reset op inner state.
256    fn reset_op_states(&mut self) -> TractResult<()> {
257        let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
258        for (ix, n) in plan.model.nodes.iter().enumerate() {
259            states[ix] = if n.op().is_stateless() { None } else { n.op().state(turn_state, ix)? };
260        }
261        Ok(())
262    }
263
264    fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
265        for state in self
266            .op_states
267            .iter_mut()
268            .filter_map(Option::as_mut)
269            .filter(|s| s.init_tensor_fact().is_some())
270        {
271            state.resolve_symbols(&mut self.turn_state)?;
272        }
273        Ok(())
274    }
275
276    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
277        self.run_plan_with_eval(inputs, self::eval)
278    }
279
280    pub fn exec(&mut self) -> TractResult<()> {
281        self.exec_plan_with_eval(self::eval)
282    }
283
284    pub fn run_plan_with_eval<Eval, E>(
285        &mut self,
286        inputs: TVec<TValue>,
287        eval: Eval,
288    ) -> TractResult<TVec<TValue>>
289    where
290        Eval: for<'a, 'b, 'c> FnMut(
291            &'a mut TurnState,
292            Option<&'b mut (dyn OpState + 'static)>,
293            &'c Node<F, O>,
294            TVec<TValue>,
295        ) -> Result<TVec<TValue>, E>,
296        E: Into<anyhow::Error> + Send + Sync + 'static,
297    {
298        self.set_inputs(inputs)?;
299        self.resolve_symbols_with_states()?;
300        self.exec_plan_with_eval(eval)?;
301        let outputs = self.outputs()?;
302        self.reset_turn()?;
303        Ok(outputs)
304    }
305
306    pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
307    where
308        Eval: for<'a, 'b, 'c> FnMut(
309            &'a mut TurnState,
310            Option<&'b mut (dyn OpState + 'static)>,
311            &'c Node<F, O>,
312            TVec<TValue>,
313        ) -> Result<TVec<TValue>, E>,
314        E: Into<anyhow::Error> + Send + Sync + 'static,
315    {
316        if let Some(executor) = self.plan().executor.as_ref() {
317            tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
318                self.do_exec_plan_with_eval(eval)
319            })
320        } else {
321            self.do_exec_plan_with_eval(eval)
322        }
323    }
324
325    fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
326    where
327        Eval: for<'a, 'b, 'c> FnMut(
328            &'a mut TurnState,
329            Option<&'b mut (dyn OpState + 'static)>,
330            &'c Node<F, O>,
331            TVec<TValue>,
332        ) -> Result<TVec<TValue>, E>,
333        E: Into<anyhow::Error> + Send + Sync + 'static,
334    {
335        {
336            self.ready_turn();
337            self.plan
338                .session_handler
339                .as_ref()
340                .map(|it| it.before_plan_eval(&mut self.turn_state))
341                .transpose()?;
342
343            for (step, n) in self.plan.order.iter().enumerate() {
344                let node = self.plan.model.node(*n);
345                trace!("Running step {step}, node {node}");
346                let mut inputs: TVec<TValue> = tvec![];
347                for i in &node.inputs {
348                    trace!("  use input {i:?}");
349                    let prec_node = self.plan.model.node(i.node);
350                    let prec = self.turn_state.values[i.node].as_ref().ok_or_else(|| {
351                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
352                    })?;
353                    inputs.push(prec[i.slot].clone())
354                }
355
356                for flush in &self.plan.flush_lists[step] {
357                    trace!("  Ran {} can now flush {}", node, self.plan.model.node(*flush));
358                    self.turn_state.values[*flush] = None;
359                }
360
361                if cfg!(debug_assertions) {
362                    let facts = self.plan.model.node_input_facts(node.id)?;
363                    if facts.len() != inputs.len() {
364                        bail!(
365                            "Evaluating {}: expected {} inputs, got {}",
366                            node,
367                            facts.len(),
368                            inputs.len()
369                        );
370                    }
371                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
372                        if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
373                            bail!(
374                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
375                                node,
376                                ix,
377                                f,
378                                v
379                            );
380                        }
381                    }
382                }
383
384                let vs = eval(
385                    &mut self.turn_state,
386                    self.op_states[node.id].as_deref_mut(),
387                    node,
388                    inputs,
389                )
390                .map_err(|e| e.into())?;
391
392                if self.plan.has_unresolved_symbols {
393                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
394                        if let Ok(f) = o.fact.to_typed_fact() {
395                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
396                                Self::resolve(
397                                    &mut self.turn_state,
398                                    dim_abstract,
399                                    *dim_concrete as i64,
400                                )?;
401                            }
402                        }
403                    }
404                }
405                if cfg!(debug_assertions) {
406                    let facts = self.plan.model.node_output_facts(node.id)?;
407                    if facts.len() != vs.len() {
408                        bail!(
409                            "Evaluating {}: expected {} outputs, got {}",
410                            node,
411                            facts.len(),
412                            vs.len()
413                        );
414                    }
415                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
416                        if node.outputs[ix].successors.len() == 0 {
417                            continue;
418                        }
419                        if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
420                            bail!(
421                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
422                                node,
423                                ix,
424                                f,
425                                v
426                            );
427                        }
428                    }
429                }
430
431                self.turn_state.values[node.id] = Some(vs);
432            }
433            self.plan
434                .session_handler
435                .as_ref()
436                .map(|it| it.after_plan_eval(&mut self.turn_state))
437                .transpose()?;
438        }
439        Ok(())
440    }
441
442    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
443        ensure!(
444            inputs.len() == self.model().inputs.len(),
445            "Wrong number of inputs for model. Expected {} got {}",
446            self.model().inputs.len(),
447            inputs.len()
448        );
449
450        for (ix, t) in inputs.into_iter().enumerate() {
451            self.set_input(ix, t)?
452        }
453        Ok(())
454    }
455
456    fn resolve(state: &mut TurnState, expression: &TDim, provided: i64) -> TractResult<()> {
457        let expected = expression.eval(&state.resolved_symbols);
458        if let Ok(x) = expected.to_i64() {
459            if x != provided {
460                bail!(
461                    "Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})"
462                )
463            }
464        }
465        if expected.symbols().len() == 1 {
466            let sym = expected.symbols().into_iter().next().unwrap();
467            if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
468                debug!("Determined symbol {sym}={v}");
469                state.resolved_symbols.set(&sym, v.to_i64().unwrap());
470            }
471            if state.scenario.is_none() {
472                let scope = sym
473                    .scope()
474                    .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
475                state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
476            }
477        }
478        Ok(())
479    }
480
481    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
482        let outlet: OutletId = *self
483            .model()
484            .input_outlets()?
485            .get(input)
486            .with_context(|| format!("Invalid input id for model ({input})."))?;
487        if let Ok(fact) = self.plan.model.outlet_fact(outlet)?.to_typed_fact() {
488            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
489                Self::resolve(&mut self.turn_state, expected, *provided as i64)?;
490            }
491        }
492        let fact = self.plan.model.outlet_fact(outlet)?;
493        ensure!(
494            fact.matches(&t, Some(&self.turn_state.resolved_symbols))
495                .with_context(|| format!("Setting input {input}"))?,
496            "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
497        );
498        self.ready_turn();
499        self.turn_state.values[outlet.node] = Some(tvec!(t));
500        Ok(())
501    }
502
503    pub fn output(&self, id: usize) -> TractResult<&TValue> {
504        let outlet = self.model().output_outlets()?.get(id).with_context(|| {
505            format!(
506                "Required output {}, only have {}",
507                id,
508                self.model().output_outlets().unwrap().len()
509            )
510        })?;
511        let value: &TValue = self
512            .turn_state
513            .values
514            .get(outlet.node)
515            .context("node id for output beyond node values array")?
516            .as_ref()
517            .context("node is not an output")?
518            .get(outlet.slot)
519            .context("slot id too high")?;
520        Ok(value)
521    }
522
523    pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
524        let &mut SimpleState { ref plan, ref mut turn_state, .. } = self;
525        let mut v = tvec![];
526        for o in plan.outputs.iter() {
527            let vs = turn_state.values[o.node].as_mut().ok_or_else(|| {
528                format_err!("Outputs of {:?} are not computed", &plan.model.nodes()[o.node])
529            })?;
530            v.push(vs[o.slot].clone())
531        }
532        Ok(v)
533    }
534
535    pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
536        self.turn_state.values[id] = Some(values);
537        Ok(())
538    }
539
540    pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
541        self.set_values(id, tvec!(value))
542    }
543
544    pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
545        let SimpleState { plan, turn_state, .. } = self;
546        let nodes = plan.model.nodes();
547        let node = &nodes[node];
548        let mut inputs: TVec<TValue> = tvec![];
549        for i in &node.inputs {
550            let prec_node = &nodes[i.node];
551            let prec = turn_state.values[i.node].as_ref().ok_or_else(|| {
552                format_err!("Computing {}, precursor {} not done.", node, prec_node)
553            })?;
554            inputs.push(prec[i.slot].clone())
555        }
556        Ok(inputs)
557    }
558
559    pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
560        let inputs = self.prepare_inputs(node)?;
561        self.compute_one_with_inputs(node, inputs)
562    }
563
564    pub fn compute_one_with_inputs(
565        &mut self,
566        node: usize,
567        inputs: TVec<TValue>,
568    ) -> TractResult<()> {
569        let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
570        let nodes = plan.model.nodes();
571        let node = &nodes[node];
572        let vs = eval(turn_state, states[node.id].as_deref_mut(), node, inputs)?;
573        turn_state.values[node.id] = Some(vs);
574        Ok(())
575    }
576
577    pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
578        let values = {
579            #[allow(clippy::needless_collect)] // clippy bug ?
580            let precs: Vec<usize> =
581                self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
582            for i in precs.into_iter() {
583                if self.turn_state.values[i].is_none() {
584                    let _ = self.compute_recursively(i)?;
585                }
586            }
587            let mut inputs: TVec<TValue> = tvec![];
588            {
589                let node = &self.model().nodes()[node];
590                for i in &node.inputs {
591                    inputs.push(self.turn_state.values[i.node].as_ref().unwrap()[i.slot].clone())
592                }
593            }
594            let &mut Self {
595                op_states: ref mut states,
596                turn_state: ref mut session_state,
597                ref plan,
598                ..
599            } = self;
600            eval(session_state, states[node].as_deref_mut(), &plan.model().nodes[node], inputs)?
601        };
602        self.turn_state.values[node] = Some(values);
603        Ok(self.turn_state.values[node].as_ref().unwrap())
604    }
605
606    pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
607        let id = self.model().node_by_name(name)?.id;
608        Self::take(self, id)
609    }
610
611    pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
612        Ok(self.turn_state.values[id]
613            .take()
614            .ok_or_else(|| format_err!("Node is not computed"))?
615            .into_iter()
616            .map(|v| v.into_tensor())
617            .collect())
618    }
619
620    pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
621        &self.plan
622    }
623
624    pub fn model(&self) -> &Graph<F, O> {
625        &self.plan.model
626    }
627
628    pub fn freeze(&self) -> FrozenSimpleState<F, O> {
629        FrozenSimpleState {
630            plan: self.plan.clone(),
631            resolved_symbols: self.turn_state.resolved_symbols.clone(),
632            scenario: self.turn_state.scenario,
633            states: self.op_states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
634            values: self
635                .turn_state
636                .values
637                .iter()
638                .enumerate()
639                .map(|(ix, t)| {
640                    if self.model().nodes[ix].op_is::<Const>() {
641                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
642                    } else {
643                        None
644                    }
645                })
646                .collect(),
647        }
648    }
649}
650
651pub fn eval<F, O>(
652    session_state: &mut TurnState,
653    mut state: Option<&mut (dyn OpState + 'static)>,
654    node: &Node<F, O>,
655    input: TVec<TValue>,
656) -> TractResult<TVec<TValue>>
657where
658    F: Fact + Clone + 'static,
659    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
660{
661    // eprint!("{node} {input:?}");
662    #[allow(clippy::let_and_return)]
663    let r = match state {
664        Some(ref mut state) => state.eval(session_state, node.op(), input),
665        None => node.op().eval_with_session(node.id, session_state, input),
666    }
667    .with_context(|| format!("Evaluating {node}"));
668    // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
669    r
670}
671
/// Snapshot of a [`SimpleState`], as produced by `SimpleState::freeze`.
#[derive(Clone, Debug)]
pub struct FrozenSimpleState<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The plan the frozen state belongs to.
    plan: Arc<SimplePlan<F, O>>,
    /// Symbols resolved at freeze time.
    pub resolved_symbols: SymbolValues,
    /// Scenario at freeze time.
    pub scenario: Option<usize>,
    /// Frozen op states, one entry per node (`None` for stateless ops).
    pub states: Vec<Option<Box<dyn FrozenOpState>>>,
    /// Node values as tensors; `freeze` only fills the entries of `Const` nodes.
    pub values: Vec<Option<TVec<Tensor>>>,
}
684
685impl<F, O> FrozenSimpleState<F, O>
686where
687    F: Fact + Clone + 'static,
688    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
689{
690    pub fn unfreeze(&self) -> SimpleState<F, O> {
691        SimpleState {
692            plan: self.plan.clone(),
693            turn_state: TurnState {
694                resolved_symbols: self.resolved_symbols.clone(),
695                scenario: self.scenario,
696                cached_mmm_scratch_space: None.into(),
697                scratch_extensions: anymap3::Map::new(),
698                values: self
699                    .values
700                    .iter()
701                    .map(|t| {
702                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect())
703                    })
704                    .collect(),
705            },
706            op_states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
707        }
708    }
709}
710
#[cfg(test)]
mod test {
    use super::*;

    /// Compiles only when `T: Send`.
    fn assert_send<T: Send>() {}
    /// Compiles only when `T: Sync`.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>()
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>()
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan>()
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan>()
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState>()
    }
}