tract_core/
plan.rs

1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4use std::marker::PhantomData;
5
6use multithread::Executor;
7
8use crate::internal::*;
9use crate::model::{Fact, Graph, OutletId};
10use crate::ops::konst::Const;
11use crate::ops::FrozenOpState;
12
13use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
14
15#[derive(Clone, Debug, Default)]
16pub struct PlanOptions {
17    /// Use the simple ordering instead of the newer memory friendly one
18    pub skip_order_opt_ram: bool,
19
20    /// Override default global executor
21    pub executor: Option<Executor>,
22}
23
24pub struct SessionState {
25    pub inputs: HashMap<usize, TValue>,
26    pub resolved_symbols: SymbolValues,
27    pub scenario: Option<usize>,
28    pub tensors: HashMap<String, Tensor>,
29    pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
30    pub scratch_extensions: anymap3::Map,
31}
32
33impl Default for SessionState {
34    fn default() -> Self {
35        SessionState {
36            inputs: HashMap::default(),
37            resolved_symbols: SymbolValues::default(),
38            tensors: HashMap::default(),
39            scenario: None,
40            cached_mmm_scratch_space: None.into(),
41            scratch_extensions: anymap3::Map::new(),
42        }
43    }
44}
45
46impl Clone for SessionState {
47    fn clone(&self) -> Self {
48        SessionState {
49            inputs: self.inputs.clone(),
50            resolved_symbols: self.resolved_symbols.clone(),
51            tensors: self.tensors.clone(),
52            scenario: self.scenario,
53            cached_mmm_scratch_space: None.into(),
54            scratch_extensions: anymap3::Map::new(),
55        }
56    }
57}
58
59pub trait SessionStateHandler: Send + Sync + Debug {
60    fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()>;
61    fn after_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()>;
62}
63
64impl Debug for SessionState {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        write!(f, "SessionState({:?})", self.resolved_symbols)
67    }
68}
69
70#[derive(Debug, Clone)]
71pub struct SimplePlan<F, O, M>
72where
73    F: Fact + Clone + 'static,
74    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
75    M: Borrow<Graph<F, O>>,
76{
77    model: M,
78    outputs: Vec<OutletId>,
79    order: Vec<usize>,
80    flush_lists: Vec<TVec<usize>>,
81    has_unresolved_symbols: bool,
82    executor: Option<Executor>,
83    session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
84    _casper: PhantomData<(F, O)>,
85}
86
87impl<F, O, M> SimplePlan<F, O, M>
88where
89    F: Fact + Clone + 'static,
90    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
91    M: Borrow<Graph<F, O>>,
92{
93    /// This contructor returns a plan that will compute all the model default outputs in one pass.
94    pub fn new(model: M) -> TractResult<SimplePlan<F, O, M>> {
95        let outputs = model.borrow().output_outlets()?.to_vec();
96        Self::build(model, &outputs, &[], &PlanOptions::default())
97    }
98
99    /// This contructor returns a plan that will compute all the model default outputs in one pass.
100    pub fn new_with_options(model: M, options: &PlanOptions) -> TractResult<SimplePlan<F, O, M>> {
101        let outputs = model.borrow().output_outlets()?.to_vec();
102        Self::build(model, &outputs, &[], options)
103    }
104
105    /// This contructor returns a plan that will compute the specified output.
106    #[deprecated]
107    pub fn new_for_output(model: M, output: OutletId) -> TractResult<SimplePlan<F, O, M>> {
108        Self::build(model, &[output], &[], &PlanOptions::default())
109    }
110
111    /// This contructor returns a plan that will compute all specified outputs in one pass.
112    #[deprecated]
113    pub fn new_for_outputs(model: M, outputs: &[OutletId]) -> TractResult<SimplePlan<F, O, M>> {
114        Self::build(model, outputs, &[], &PlanOptions::default())
115    }
116
117    pub fn with_session_handler<H: SessionStateHandler + 'static>(
118        mut self,
119        session_handler: H,
120    ) -> Self {
121        self.session_handler = Some(Arc::new(session_handler));
122        self
123    }
124
125    #[deprecated]
126    pub fn new_for_outputs_and_deps(
127        model: M,
128        outputs: &[OutletId],
129        deps: &[(usize, usize)],
130    ) -> TractResult<SimplePlan<F, O, M>> {
131        Self::build(model, outputs, deps, &PlanOptions::default())
132    }
133
134    pub fn build(
135        model: M,
136        outputs: &[OutletId],
137        deps: &[(usize, usize)],
138        options: &PlanOptions,
139    ) -> TractResult<SimplePlan<F, O, M>> {
140        let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
141        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
142        let mut order = if options.skip_order_opt_ram {
143            eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?
144        } else {
145            eval_order_opt_ram_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?
146        };
147        order.retain(|node| !model.borrow().node(*node).op_is::<Const>());
148        let flush_lists =
149            build_flush_list(model.borrow(), &order, outputs, |n| !n.op_is::<Const>());
150
151        #[allow(clippy::mutable_key_type)]
152        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
153        for node in &model.borrow().nodes {
154            for output in &node.outputs {
155                if let Ok(fact) = output.fact.to_typed_fact() {
156                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
157                }
158            }
159        }
160        Ok(SimplePlan {
161            model,
162            order,
163            flush_lists,
164            outputs: outputs.to_vec(),
165            has_unresolved_symbols: !symbols.is_empty(),
166            _casper: PhantomData,
167            executor: options.executor.clone(),
168            session_handler: None,
169        })
170    }
171
172    pub fn order_without_consts(&self) -> &[usize] {
173        &self.order
174    }
175
176    pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
177        let mut state = SimpleState::new(self)?;
178        state.run(inputs)
179    }
180
181    pub fn model(&self) -> &Graph<F, O> {
182        self.model.borrow()
183    }
184}
185
186#[derive(Clone, Debug)]
187pub struct SimpleState<F, O, M, P>
188where
189    F: Fact + Clone + 'static,
190    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
191    M: Borrow<Graph<F, O>>,
192    P: Borrow<SimplePlan<F, O, M>>,
193{
194    plan: P,
195    pub states: Vec<Option<Box<dyn OpState>>>,
196    pub session_state: SessionState,
197    pub values: Vec<Option<TVec<TValue>>>,
198    _phantom: PhantomData<(M, F, O)>,
199}
200
201impl<F, O, M, P> SimpleState<F, O, M, P>
202where
203    F: Fact + Clone + 'static,
204    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
205    M: Borrow<Graph<F, O>>,
206    P: Borrow<SimplePlan<F, O, M>> + Clone,
207{
208    pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
209        let values = vec![None; plan.borrow().model.borrow().nodes().len()];
210        let session = SessionState::default();
211        let model = plan.borrow().model();
212        let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
213        let mut state =
214            SimpleState { plan, states, session_state: session, values, _phantom: PhantomData };
215        state.populate_consts();
216        state.reset_op_states()?;
217        Ok(state)
218    }
219
220    pub fn new_from_inputs(plan: P, inputs: TVec<TValue>) -> TractResult<SimpleState<F, O, M, P>> {
221        let mut state = SimpleState::new(plan)?;
222        state.set_inputs(inputs)?;
223
224        Ok(state)
225    }
226
227    fn populate_consts(&mut self) {
228        for node in &self.plan.borrow().model().nodes {
229            if let Some(k) = node.op_as::<Const>() {
230                self.values[node.id] = Some(tvec!(k.0.clone().into_tvalue()));
231            }
232        }
233    }
234
235    /// Reset wires state.
236    pub fn reset_turn(&mut self) -> TractResult<()> {
237        for node in &self.plan.borrow().order {
238            self.values[*node] = None;
239        }
240        self.session_state.resolved_symbols = SymbolValues::default();
241        Ok(())
242    }
243
244    /// Reset op inner state.
245    pub fn reset_op_states(&mut self) -> TractResult<()> {
246        let &mut SimpleState { ref plan, ref mut session_state, ref mut states, .. } = self;
247        for (ix, n) in plan.borrow().model().nodes().iter().enumerate() {
248            states[ix] =
249                if n.op().is_stateless() { None } else { n.op().state(session_state, ix)? };
250        }
251        Ok(())
252    }
253
254    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
255        self.run_plan_with_eval(inputs, self::eval)
256    }
257
258    pub fn exec(&mut self) -> TractResult<()> {
259        self.exec_plan_with_eval(self::eval)
260    }
261
262    pub fn run_plan_with_eval<Eval, E>(
263        &mut self,
264        inputs: TVec<TValue>,
265        eval: Eval,
266    ) -> TractResult<TVec<TValue>>
267    where
268        Eval: for<'a, 'b, 'c> FnMut(
269            &'a mut SessionState,
270            Option<&'b mut (dyn OpState + 'static)>,
271            &'c Node<F, O>,
272            TVec<TValue>,
273        ) -> Result<TVec<TValue>, E>,
274        E: Into<anyhow::Error> + Send + Sync + 'static,
275    {
276        self.set_inputs(inputs)?;
277        self.exec_plan_with_eval(eval)?;
278        let outputs = self.outputs()?;
279        self.reset_turn()?;
280        Ok(outputs)
281    }
282
283    pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
284    where
285        Eval: for<'a, 'b, 'c> FnMut(
286            &'a mut SessionState,
287            Option<&'b mut (dyn OpState + 'static)>,
288            &'c Node<F, O>,
289            TVec<TValue>,
290        ) -> Result<TVec<TValue>, E>,
291        E: Into<anyhow::Error> + Send + Sync + 'static,
292    {
293        if let Some(executor) = self.plan().executor.as_ref() {
294            tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
295                self.do_exec_plan_with_eval(eval)
296            })
297        } else {
298            self.do_exec_plan_with_eval(eval)
299        }
300    }
301
302    fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
303    where
304        Eval: for<'a, 'b, 'c> FnMut(
305            &'a mut SessionState,
306            Option<&'b mut (dyn OpState + 'static)>,
307            &'c Node<F, O>,
308            TVec<TValue>,
309        ) -> Result<TVec<TValue>, E>,
310        E: Into<anyhow::Error> + Send + Sync + 'static,
311    {
312        {
313            let plan = self.plan.borrow();
314            let model = plan.model.borrow();
315            plan.session_handler
316                .as_ref()
317                .map(|it| it.before_plan_eval(&mut self.session_state))
318                .transpose()?;
319
320            for (step, n) in plan.order.iter().enumerate() {
321                let node = model.node(*n);
322                trace!("Running step {}, node {}", step, node);
323                let mut inputs: TVec<TValue> = tvec![];
324                for i in &node.inputs {
325                    trace!("  use input {:?}", i);
326                    let prec_node = model.node(i.node);
327                    let prec = self.values[i.node].as_ref().ok_or_else(|| {
328                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
329                    })?;
330                    inputs.push(prec[i.slot].clone())
331                }
332
333                for flush in &plan.flush_lists[step] {
334                    trace!("  Ran {} can now flush {}", node, model.node(*flush));
335                    self.values[*flush] = None;
336                }
337
338                if cfg!(debug_assertions) {
339                    let facts = model.node_input_facts(node.id)?;
340                    if facts.len() != inputs.len() {
341                        bail!(
342                            "Evaluating {}: expected {} inputs, got {}",
343                            node,
344                            facts.len(),
345                            inputs.len()
346                        );
347                    }
348                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
349                        if !f.matches(v, Some(&self.session_state.resolved_symbols))? {
350                            bail!(
351                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
352                                node,
353                                ix,
354                                f,
355                                v
356                            );
357                        }
358                    }
359                }
360
361                let vs = eval(
362                    &mut self.session_state,
363                    self.states[node.id].as_deref_mut(),
364                    node,
365                    inputs,
366                )
367                .map_err(|e| e.into())?;
368
369                if plan.has_unresolved_symbols {
370                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
371                        if let Ok(f) = o.fact.to_typed_fact() {
372                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
373                                Self::resolve(
374                                    &mut self.session_state,
375                                    dim_abstract,
376                                    *dim_concrete as i64,
377                                )?;
378                            }
379                        }
380                    }
381                }
382                if cfg!(debug_assertions) {
383                    let facts = model.node_output_facts(node.id)?;
384                    if facts.len() != vs.len() {
385                        bail!(
386                            "Evaluating {}: expected {} outputs, got {}",
387                            node,
388                            facts.len(),
389                            vs.len()
390                        );
391                    }
392                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
393                        if node.outputs[ix].successors.len() == 0 {
394                            continue;
395                        }
396                        if !f.matches(v, Some(&self.session_state.resolved_symbols))? {
397                            bail!(
398                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
399                                node,
400                                ix,
401                                f,
402                                v
403                            );
404                        }
405                    }
406                }
407
408                self.values[node.id] = Some(vs);
409            }
410            plan.session_handler
411                .as_ref()
412                .map(|it| it.after_plan_eval(&mut self.session_state))
413                .transpose()?;
414        }
415        Ok(())
416    }
417
418    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
419        ensure!(
420            inputs.len() == self.model().inputs.len(),
421            "Wrong number of inputs for model. Expected {} got {}",
422            self.model().inputs.len(),
423            inputs.len()
424        );
425        for (ix, t) in inputs.into_iter().enumerate() {
426            self.set_input(ix, t)?
427        }
428        Ok(())
429    }
430
431    fn resolve(state: &mut SessionState, expression: &TDim, provided: i64) -> TractResult<()> {
432        let expected = expression.eval(&state.resolved_symbols);
433        if let Ok(x) = expected.to_i64() {
434            if x != provided {
435                bail!("Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})")
436            }
437        }
438        if expected.symbols().len() == 1 {
439            let sym = expected.symbols().into_iter().next().unwrap();
440            if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
441                debug!("Determined symbol {sym}={v}");
442                state.resolved_symbols.set(&sym, v.to_i64().unwrap());
443            }
444            if state.scenario.is_none() {
445                state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
446            }
447        }
448        Ok(())
449    }
450
451    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
452        let outlet: OutletId = *self
453            .model()
454            .input_outlets()?
455            .get(input)
456            .with_context(|| format!("Invalid input id for model ({input})."))?;
457        if let Ok(fact) = self.plan.borrow().model().outlet_fact(outlet)?.to_typed_fact() {
458            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
459                Self::resolve(&mut self.session_state, expected, *provided as i64)?;
460            }
461        }
462        let fact = self.plan.borrow().model().outlet_fact(outlet)?;
463        ensure!(
464            fact.matches(&t, Some(&self.session_state.resolved_symbols))
465            .with_context(|| format!("Setting input {input}"))?,
466            "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
467            );
468        self.session_state.inputs.insert(outlet.node, t);
469        Ok(())
470    }
471
472    pub fn output(&self, id: usize) -> TractResult<&TValue> {
473        let outlet = self.model().output_outlets()?.get(id).with_context(|| {
474            format!(
475                "Required output {}, only have {}",
476                id,
477                self.model().output_outlets().unwrap().len()
478            )
479        })?;
480        let value: &TValue = self
481            .values
482            .get(outlet.node)
483            .context("node id for output beyond node values array")?
484            .as_ref()
485            .context("node is not an output")?
486            .get(outlet.slot)
487            .context("slot id too high")?;
488        Ok(value)
489    }
490
491    pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
492        let SimpleState { ref plan, ref mut values, .. } = self;
493        let mut v = tvec![];
494        for o in plan.borrow().outputs.iter() {
495            let vs = values[o.node].as_mut().ok_or_else(|| {
496                format_err!(
497                    "Outputs of {:?} are not computed",
498                    &plan.borrow().model().nodes()[o.node]
499                )
500            })?;
501            v.push(vs[o.slot].clone())
502        }
503        Ok(v)
504    }
505
506    pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
507        self.values[id] = Some(values);
508        Ok(())
509    }
510
511    pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
512        self.set_values(id, tvec!(value))
513    }
514
515    pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
516        let SimpleState { ref plan, ref values, .. } = self;
517        let plan = plan.borrow();
518        let nodes = plan.model().nodes();
519        let node = &nodes[node];
520        let mut inputs: TVec<TValue> = tvec![];
521        for i in &node.inputs {
522            let prec_node = &nodes[i.node];
523            let prec = values[i.node].as_ref().ok_or_else(|| {
524                format_err!("Computing {}, precursor {} not done.", node, prec_node)
525            })?;
526            inputs.push(prec[i.slot].clone())
527        }
528        Ok(inputs)
529    }
530
531    pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
532        let inputs = self.prepare_inputs(node)?;
533        self.compute_one_with_inputs(node, inputs)
534    }
535
536    pub fn compute_one_with_inputs(
537        &mut self,
538        node: usize,
539        inputs: TVec<TValue>,
540    ) -> TractResult<()> {
541        let SimpleState { ref plan, ref mut session_state, ref mut values, ref mut states, .. } =
542            self;
543        let plan = plan.borrow();
544        let nodes = plan.model().nodes();
545        let node = &nodes[node];
546        let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)?;
547        values[node.id] = Some(vs);
548        Ok(())
549    }
550
551    pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
552        let values = {
553            #[allow(clippy::needless_collect)] // clippy bug ?
554            let precs: Vec<usize> =
555                self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
556            for i in precs.into_iter() {
557                if self.values[i].is_none() {
558                    let _ = self.compute_recursively(i)?;
559                }
560            }
561            let mut inputs: TVec<TValue> = tvec![];
562            {
563                let node = &self.model().nodes()[node];
564                for i in &node.inputs {
565                    inputs.push(self.values[i.node].as_ref().unwrap()[i.slot].clone())
566                }
567            }
568            let Self { ref mut states, ref mut session_state, ref plan, .. } = self;
569            eval(
570                session_state,
571                states[node].as_deref_mut(),
572                &plan.borrow().model().nodes[node],
573                inputs,
574            )?
575        };
576        self.values[node] = Some(values);
577        Ok(self.values[node].as_ref().unwrap())
578    }
579
580    pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
581        let id = self.model().node_by_name(name)?.id;
582        Self::take(self, id)
583    }
584
585    pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
586        Ok(self.values[id]
587            .take()
588            .ok_or_else(|| format_err!("Node is not computed"))?
589            .into_iter()
590            .map(|v| v.into_tensor())
591            .collect())
592    }
593
594    pub fn plan(&self) -> &SimplePlan<F, O, M> {
595        self.plan.borrow()
596    }
597
598    pub fn model(&self) -> &Graph<F, O> {
599        self.plan().model()
600    }
601
602    pub fn freeze(&self) -> FrozenSimpleState<F, O, M, P> {
603        FrozenSimpleState {
604            plan: self.plan.clone(),
605            inputs: self
606                .session_state
607                .inputs
608                .iter()
609                .map(|(ix, t)| (*ix, t.clone().into_tensor()))
610                .collect(),
611            resolved_symbols: self.session_state.resolved_symbols.clone(),
612            scenario: self.session_state.scenario,
613            tensors: self.session_state.tensors.clone(),
614            states: self.states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
615            values: self
616                .values
617                .iter()
618                .enumerate()
619                .map(|(ix, t)| {
620                    if self.model().nodes[ix].op_is::<Const>() {
621                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
622                    } else {
623                        None
624                    }
625                })
626                .collect(),
627            _phantom: PhantomData,
628        }
629    }
630}
631
632pub fn eval<F, O>(
633    session_state: &mut SessionState,
634    mut state: Option<&mut (dyn OpState + 'static)>,
635    node: &Node<F, O>,
636    input: TVec<TValue>,
637) -> TractResult<TVec<TValue>>
638where
639    F: Fact + Clone + 'static,
640    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
641{
642    // eprint!("{node} {input:?}");
643    let r = match state {
644        Some(ref mut state) => state.eval(session_state, node.op(), input),
645        None => node.op().eval_with_session(session_state, input),
646    }
647    .with_context(|| format!("Evaluating {node}"));
648    // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
649    r
650}
651
652#[derive(Clone, Debug)]
653pub struct FrozenSimpleState<F, O, M, P>
654where
655    F: Fact + Clone + 'static,
656    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
657    M: Borrow<Graph<F, O>>,
658    P: Borrow<SimplePlan<F, O, M>> + Clone,
659{
660    plan: P,
661    pub inputs: HashMap<usize, Tensor>,
662    pub resolved_symbols: SymbolValues,
663    pub scenario: Option<usize>,
664    pub tensors: HashMap<String, Tensor>,
665    pub states: Vec<Option<Box<dyn FrozenOpState>>>,
666    pub values: Vec<Option<TVec<Tensor>>>,
667    _phantom: PhantomData<(M, F, O)>,
668}
669
670impl<F, O, M, P> FrozenSimpleState<F, O, M, P>
671where
672    F: Fact + Clone + 'static,
673    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
674    M: Borrow<Graph<F, O>>,
675    P: Borrow<SimplePlan<F, O, M>> + Clone,
676{
677    pub fn unfreeze(&self) -> SimpleState<F, O, M, P> {
678        let mut state = SimpleState {
679            plan: self.plan.clone(),
680            session_state: SessionState {
681                inputs: self.inputs.iter().map(|(ix, t)| (*ix, t.clone().into_tvalue())).collect(),
682                resolved_symbols: self.resolved_symbols.clone(),
683                scenario: self.scenario,
684                tensors: self.tensors.clone(),
685                cached_mmm_scratch_space: None.into(),
686                scratch_extensions: anymap3::Map::new(),
687            },
688            states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
689            values: self
690                .values
691                .iter()
692                .map(|t| t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect()))
693                .collect(),
694            _phantom: PhantomData,
695        };
696        state.populate_consts();
697        state
698    }
699}
700
#[cfg(test)]
mod test {
    use super::*;

    /// Compiles only when `T: Send`; calling it is a compile-time assertion.
    fn assert_send<T: Send>() {}

    /// Compiles only when `T: Sync`; calling it is a compile-time assertion.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>();
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>();
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan<TypedModel>>();
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan<TypedModel>>();
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState<TypedModel, TypedSimplePlan<TypedModel>>>();
    }
}