Skip to main content

tract_core/
plan.rs

1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4use std::marker::PhantomData;
5
6use multithread::Executor;
7use tract_data::itertools::Itertools;
8
9use crate::internal::*;
10use crate::model::{Fact, Graph, OutletId};
11use crate::ops::konst::Const;
12use crate::ops::FrozenOpState;
13
14use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
15
/// Options controlling how a [`SimplePlan`] is built (consumed by `SimplePlan::build`).
#[derive(Clone, Debug, Default)]
pub struct PlanOptions {
    /// Use the simple ordering instead of the newer memory friendly one
    ///
    /// When `true`, the plan order comes from `eval_order_for_nodes`; otherwise from
    /// `eval_order_opt_ram_for_nodes`.
    pub skip_order_opt_ram: bool,

    /// Override default global executor
    ///
    /// When `Some`, plan evaluation runs inside `multithread_tract_scope` with this executor.
    pub executor: Option<Executor>,
}
24
/// Mutable per-session data, threaded through every op evaluation of a plan.
///
/// The manual `Clone` impl copies inputs, symbols, scenario and named tensors, but gives the
/// copy empty scratch caches.
pub struct SessionState {
    /// Model input values for the current run, keyed by input node id
    /// (filled by `SimpleState::set_input`).
    pub inputs: HashMap<usize, TValue>,
    /// Concrete values of symbols learned so far from input/output shapes; cleared by
    /// `SimpleState::reset_turn`.
    pub resolved_symbols: SymbolValues,
    /// Scenario index guessed by the symbol scope from resolved symbols, if any. It is only
    /// guessed while `None`, and `SimpleState::reset_turn` does not clear it.
    pub scenario: Option<usize>,
    /// Named tensors carried by the session. NOTE(review): not read in this file; presumably
    /// used by ops — confirm.
    pub tensors: HashMap<String, Tensor>,
    /// Cached matrix-multiply scratch space, `None` until something stores one (presumably the
    /// matmul ops — confirm). Wrapped in `RefCell` for interior mutability; not copied by
    /// `clone()`.
    pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
    /// Type-keyed map (`anymap3`) for extra per-session scratch data; not copied by `clone()`.
    pub scratch_extensions: anymap3::Map,
}
33
34impl Default for SessionState {
35    fn default() -> Self {
36        SessionState {
37            inputs: HashMap::default(),
38            resolved_symbols: SymbolValues::default(),
39            tensors: HashMap::default(),
40            scenario: None,
41            cached_mmm_scratch_space: None.into(),
42            scratch_extensions: anymap3::Map::new(),
43        }
44    }
45}
46
47impl Clone for SessionState {
48    fn clone(&self) -> Self {
49        SessionState {
50            inputs: self.inputs.clone(),
51            resolved_symbols: self.resolved_symbols.clone(),
52            tensors: self.tensors.clone(),
53            scenario: self.scenario,
54            cached_mmm_scratch_space: None.into(),
55            scratch_extensions: anymap3::Map::new(),
56        }
57    }
58}
59
/// Hooks run around each plan evaluation performed by `SimpleState` (attached with
/// `SimplePlan::with_session_handler`).
pub trait SessionStateHandler: Send + Sync + Debug {
    /// Called once before the first node of the plan is evaluated. An error aborts the run.
    fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()>;
    /// Called once after the last node has been evaluated. It is not called when
    /// `before_plan_eval` or a node evaluation failed, since the error is propagated first.
    fn after_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()>;
}
64
65impl Debug for SessionState {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        write!(f, "SessionState({:?})", self.resolved_symbols)
68    }
69}
70
/// A precomputed evaluation schedule for a graph: which nodes to run, in which order, and
/// when intermediate values can be dropped.
#[derive(Debug, Clone)]
pub struct SimplePlan<F, O, M>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
    M: Borrow<Graph<F, O>>,
{
    /// The graph, owned or borrowed (anything that borrows as a `Graph<F, O>`).
    model: M,
    /// Outlets whose values `SimpleState::outputs` returns.
    outputs: Vec<OutletId>,
    /// Node ids in evaluation order, with `Const` nodes removed.
    order: Vec<usize>,
    /// For each step of `order`, the node ids whose values are dropped once that step's
    /// inputs have been gathered.
    flush_lists: Vec<TVec<usize>>,
    /// True when some typed output fact has a symbolic dimension; enables resolving symbols
    /// from computed output shapes during evaluation.
    has_unresolved_symbols: bool,
    /// Optional executor override (from `PlanOptions::executor`).
    executor: Option<Executor>,
    /// Optional hooks run before/after each plan evaluation.
    session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
    _casper: PhantomData<(F, O)>,
}
87
88impl<F, O, M> SimplePlan<F, O, M>
89where
90    F: Fact + Clone + 'static,
91    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
92    M: Borrow<Graph<F, O>>,
93{
94    /// This contructor returns a plan that will compute all the model default outputs in one pass.
95    pub fn new(model: M) -> TractResult<SimplePlan<F, O, M>> {
96        let outputs = model.borrow().output_outlets()?.to_vec();
97        Self::build(model, &outputs, &[], &PlanOptions::default())
98    }
99
100    /// This contructor returns a plan that will compute all the model default outputs in one pass.
101    pub fn new_with_options(model: M, options: &PlanOptions) -> TractResult<SimplePlan<F, O, M>> {
102        let outputs = model.borrow().output_outlets()?.to_vec();
103        Self::build(model, &outputs, &[], options)
104    }
105
106    /// This contructor returns a plan that will compute the specified output.
107    #[deprecated]
108    pub fn new_for_output(model: M, output: OutletId) -> TractResult<SimplePlan<F, O, M>> {
109        Self::build(model, &[output], &[], &PlanOptions::default())
110    }
111
112    /// This contructor returns a plan that will compute all specified outputs in one pass.
113    #[deprecated]
114    pub fn new_for_outputs(model: M, outputs: &[OutletId]) -> TractResult<SimplePlan<F, O, M>> {
115        Self::build(model, outputs, &[], &PlanOptions::default())
116    }
117
118    pub fn with_session_handler<H: SessionStateHandler + 'static>(
119        mut self,
120        session_handler: H,
121    ) -> Self {
122        self.session_handler = Some(Arc::new(session_handler));
123        self
124    }
125
126    #[deprecated]
127    pub fn new_for_outputs_and_deps(
128        model: M,
129        outputs: &[OutletId],
130        deps: &[(usize, usize)],
131    ) -> TractResult<SimplePlan<F, O, M>> {
132        Self::build(model, outputs, deps, &PlanOptions::default())
133    }
134
135    pub fn build(
136        model: M,
137        outputs: &[OutletId],
138        deps: &[(usize, usize)],
139        options: &PlanOptions,
140    ) -> TractResult<SimplePlan<F, O, M>> {
141        let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
142        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
143        let mut order = if options.skip_order_opt_ram {
144            eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?
145        } else {
146            eval_order_opt_ram_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?
147        };
148        order.retain(|node| !model.borrow().node(*node).op_is::<Const>());
149        let flush_lists =
150            build_flush_list(model.borrow(), &order, outputs, |n| !n.op_is::<Const>());
151
152        #[allow(clippy::mutable_key_type)]
153        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
154        for node in &model.borrow().nodes {
155            for output in &node.outputs {
156                if let Ok(fact) = output.fact.to_typed_fact() {
157                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
158                }
159            }
160        }
161        Ok(SimplePlan {
162            model,
163            order,
164            flush_lists,
165            outputs: outputs.to_vec(),
166            has_unresolved_symbols: !symbols.is_empty(),
167            _casper: PhantomData,
168            executor: options.executor.clone(),
169            session_handler: None,
170        })
171    }
172
173    pub fn order_without_consts(&self) -> &[usize] {
174        &self.order
175    }
176
177    pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
178        let mut state = SimpleState::new(self)?;
179        state.run(inputs)
180    }
181
182    pub fn model(&self) -> &Graph<F, O> {
183        self.model.borrow()
184    }
185}
186
/// Runtime state for evaluating a [`SimplePlan`]: per-node values, per-node op states and the
/// session state.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O, M, P>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
    M: Borrow<Graph<F, O>>,
    P: Borrow<SimplePlan<F, O, M>>,
{
    /// The plan, owned or borrowed.
    plan: P,
    /// Op state per node, indexed by node id; `None` for stateless ops.
    pub states: Vec<Option<Box<dyn OpState>>>,
    /// Session data shared by all op evaluations.
    pub session_state: SessionState,
    /// Output values per node, indexed by node id; `None` when not computed yet, flushed
    /// during evaluation, taken, or cleared by `reset_turn`.
    pub values: Vec<Option<TVec<TValue>>>,
    _phantom: PhantomData<(M, F, O)>,
}
201
202impl<F, O, M, P> SimpleState<F, O, M, P>
203where
204    F: Fact + Clone + 'static,
205    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
206    M: Borrow<Graph<F, O>>,
207    P: Borrow<SimplePlan<F, O, M>> + Clone,
208{
209    pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
210        let values = vec![None; plan.borrow().model.borrow().nodes().len()];
211        let session = SessionState::default();
212        let model = plan.borrow().model();
213        let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
214        let mut state =
215            SimpleState { plan, states, session_state: session, values, _phantom: PhantomData };
216        state.populate_consts();
217        state.reset_op_states()?;
218        Ok(state)
219    }
220
221    pub fn new_from_inputs(plan: P, inputs: TVec<TValue>) -> TractResult<SimpleState<F, O, M, P>> {
222        let mut state = SimpleState::new(plan)?;
223        state.set_inputs(inputs)?;
224        state.resolve_symbols_with_states()?;
225
226        Ok(state)
227    }
228
229    fn populate_consts(&mut self) {
230        for node in &self.plan.borrow().model().nodes {
231            if let Some(k) = node.op_as::<Const>() {
232                self.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
233            }
234        }
235    }
236
237    /// Reset wires state.
238    pub fn reset_turn(&mut self) -> TractResult<()> {
239        for node in &self.plan.borrow().order {
240            self.values[*node] = None;
241        }
242        self.session_state.resolved_symbols = SymbolValues::default();
243        Ok(())
244    }
245
246    /// Reset op inner state.
247    pub fn reset_op_states(&mut self) -> TractResult<()> {
248        let &mut SimpleState { ref plan, ref mut session_state, ref mut states, .. } = self;
249        for (ix, n) in plan.borrow().model().nodes().iter().enumerate() {
250            states[ix] =
251                if n.op().is_stateless() { None } else { n.op().state(session_state, ix)? };
252        }
253        Ok(())
254    }
255
256    pub fn init_states(&mut self, state_init_tensors: &mut Vec<TValue>) -> TractResult<()> {
257        let states_to_init = self
258            .states
259            .iter_mut()
260            .filter_map(Option::as_mut)
261            .filter(|s| s.init_tensor_fact().is_some())
262            .collect_vec();
263        ensure!(
264            states_to_init.len() == state_init_tensors.len(),
265            "There are {} op to init but got {} tensors",
266            states_to_init.len(),
267            state_init_tensors.len()
268        );
269        for state in states_to_init {
270            state.load_from(&mut self.session_state, state_init_tensors)?;
271        }
272        Ok(())
273    }
274
275    fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
276        for state in self
277            .states
278            .iter_mut()
279            .filter_map(Option::as_mut)
280            .filter(|s| s.init_tensor_fact().is_some())
281        {
282            state.resolve_symbols(&mut self.session_state)?;
283        }
284        Ok(())
285    }
286
287    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
288        self.run_plan_with_eval(inputs, self::eval)
289    }
290
291    pub fn exec(&mut self) -> TractResult<()> {
292        self.exec_plan_with_eval(self::eval)
293    }
294
295    pub fn run_plan_with_eval<Eval, E>(
296        &mut self,
297        inputs: TVec<TValue>,
298        eval: Eval,
299    ) -> TractResult<TVec<TValue>>
300    where
301        Eval: for<'a, 'b, 'c> FnMut(
302            &'a mut SessionState,
303            Option<&'b mut (dyn OpState + 'static)>,
304            &'c Node<F, O>,
305            TVec<TValue>,
306        ) -> Result<TVec<TValue>, E>,
307        E: Into<anyhow::Error> + Send + Sync + 'static,
308    {
309        self.set_inputs(inputs)?;
310        self.resolve_symbols_with_states()?;
311        self.exec_plan_with_eval(eval)?;
312        let outputs = self.outputs()?;
313        self.reset_turn()?;
314        Ok(outputs)
315    }
316
317    pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
318    where
319        Eval: for<'a, 'b, 'c> FnMut(
320            &'a mut SessionState,
321            Option<&'b mut (dyn OpState + 'static)>,
322            &'c Node<F, O>,
323            TVec<TValue>,
324        ) -> Result<TVec<TValue>, E>,
325        E: Into<anyhow::Error> + Send + Sync + 'static,
326    {
327        if let Some(executor) = self.plan().executor.as_ref() {
328            tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
329                self.do_exec_plan_with_eval(eval)
330            })
331        } else {
332            self.do_exec_plan_with_eval(eval)
333        }
334    }
335
336    fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
337    where
338        Eval: for<'a, 'b, 'c> FnMut(
339            &'a mut SessionState,
340            Option<&'b mut (dyn OpState + 'static)>,
341            &'c Node<F, O>,
342            TVec<TValue>,
343        ) -> Result<TVec<TValue>, E>,
344        E: Into<anyhow::Error> + Send + Sync + 'static,
345    {
346        {
347            let plan = self.plan.borrow();
348            let model = plan.model.borrow();
349            plan.session_handler
350                .as_ref()
351                .map(|it| it.before_plan_eval(&mut self.session_state))
352                .transpose()?;
353
354            for (step, n) in plan.order.iter().enumerate() {
355                let node = model.node(*n);
356                trace!("Running step {step}, node {node}");
357                let mut inputs: TVec<TValue> = tvec![];
358                for i in &node.inputs {
359                    trace!("  use input {i:?}");
360                    let prec_node = model.node(i.node);
361                    let prec = self.values[i.node].as_ref().ok_or_else(|| {
362                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
363                    })?;
364                    inputs.push(prec[i.slot].clone())
365                }
366
367                for flush in &plan.flush_lists[step] {
368                    trace!("  Ran {} can now flush {}", node, model.node(*flush));
369                    self.values[*flush] = None;
370                }
371
372                if cfg!(debug_assertions) {
373                    let facts = model.node_input_facts(node.id)?;
374                    if facts.len() != inputs.len() {
375                        bail!(
376                            "Evaluating {}: expected {} inputs, got {}",
377                            node,
378                            facts.len(),
379                            inputs.len()
380                        );
381                    }
382                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
383                        if !f.matches(v, Some(&self.session_state.resolved_symbols))? {
384                            bail!(
385                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
386                                node,
387                                ix,
388                                f,
389                                v
390                            );
391                        }
392                    }
393                }
394
395                let vs = eval(
396                    &mut self.session_state,
397                    self.states[node.id].as_deref_mut(),
398                    node,
399                    inputs,
400                )
401                .map_err(|e| e.into())?;
402
403                if plan.has_unresolved_symbols {
404                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
405                        if let Ok(f) = o.fact.to_typed_fact() {
406                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
407                                Self::resolve(
408                                    &mut self.session_state,
409                                    dim_abstract,
410                                    *dim_concrete as i64,
411                                )?;
412                            }
413                        }
414                    }
415                }
416                if cfg!(debug_assertions) {
417                    let facts = model.node_output_facts(node.id)?;
418                    if facts.len() != vs.len() {
419                        bail!(
420                            "Evaluating {}: expected {} outputs, got {}",
421                            node,
422                            facts.len(),
423                            vs.len()
424                        );
425                    }
426                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
427                        if node.outputs[ix].successors.len() == 0 {
428                            continue;
429                        }
430                        if !f.matches(v, Some(&self.session_state.resolved_symbols))? {
431                            bail!(
432                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
433                                node,
434                                ix,
435                                f,
436                                v
437                            );
438                        }
439                    }
440                }
441
442                self.values[node.id] = Some(vs);
443            }
444            plan.session_handler
445                .as_ref()
446                .map(|it| it.after_plan_eval(&mut self.session_state))
447                .transpose()?;
448        }
449        Ok(())
450    }
451
452    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
453        ensure!(
454            inputs.len() == self.model().inputs.len(),
455            "Wrong number of inputs for model. Expected {} got {}",
456            self.model().inputs.len(),
457            inputs.len()
458        );
459
460        for (ix, t) in inputs.into_iter().enumerate() {
461            self.set_input(ix, t)?
462        }
463        Ok(())
464    }
465
466    fn resolve(state: &mut SessionState, expression: &TDim, provided: i64) -> TractResult<()> {
467        let expected = expression.eval(&state.resolved_symbols);
468        if let Ok(x) = expected.to_i64() {
469            if x != provided {
470                bail!("Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})")
471            }
472        }
473        if expected.symbols().len() == 1 {
474            let sym = expected.symbols().into_iter().next().unwrap();
475            if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
476                debug!("Determined symbol {sym}={v}");
477                state.resolved_symbols.set(&sym, v.to_i64().unwrap());
478            }
479            if state.scenario.is_none() {
480                let scope = sym
481                    .scope()
482                    .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
483                state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
484            }
485        }
486        Ok(())
487    }
488
489    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
490        let outlet: OutletId = *self
491            .model()
492            .input_outlets()?
493            .get(input)
494            .with_context(|| format!("Invalid input id for model ({input})."))?;
495        if let Ok(fact) = self.plan.borrow().model().outlet_fact(outlet)?.to_typed_fact() {
496            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
497                Self::resolve(&mut self.session_state, expected, *provided as i64)?;
498            }
499        }
500        let fact = self.plan.borrow().model().outlet_fact(outlet)?;
501        ensure!(
502            fact.matches(&t, Some(&self.session_state.resolved_symbols))
503            .with_context(|| format!("Setting input {input}"))?,
504            "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
505            );
506        self.session_state.inputs.insert(outlet.node, t);
507        Ok(())
508    }
509
510    pub fn output(&self, id: usize) -> TractResult<&TValue> {
511        let outlet = self.model().output_outlets()?.get(id).with_context(|| {
512            format!(
513                "Required output {}, only have {}",
514                id,
515                self.model().output_outlets().unwrap().len()
516            )
517        })?;
518        let value: &TValue = self
519            .values
520            .get(outlet.node)
521            .context("node id for output beyond node values array")?
522            .as_ref()
523            .context("node is not an output")?
524            .get(outlet.slot)
525            .context("slot id too high")?;
526        Ok(value)
527    }
528
529    pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
530        let &mut SimpleState { ref plan, ref mut values, .. } = self;
531        let mut v = tvec![];
532        for o in plan.borrow().outputs.iter() {
533            let vs = values[o.node].as_mut().ok_or_else(|| {
534                format_err!(
535                    "Outputs of {:?} are not computed",
536                    &plan.borrow().model().nodes()[o.node]
537                )
538            })?;
539            v.push(vs[o.slot].clone())
540        }
541        Ok(v)
542    }
543
544    pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
545        self.values[id] = Some(values);
546        Ok(())
547    }
548
549    pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
550        self.set_values(id, tvec!(value))
551    }
552
553    pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
554        let SimpleState { plan, values, .. } = self;
555        let plan = plan.borrow();
556        let nodes = plan.model().nodes();
557        let node = &nodes[node];
558        let mut inputs: TVec<TValue> = tvec![];
559        for i in &node.inputs {
560            let prec_node = &nodes[i.node];
561            let prec = values[i.node].as_ref().ok_or_else(|| {
562                format_err!("Computing {}, precursor {} not done.", node, prec_node)
563            })?;
564            inputs.push(prec[i.slot].clone())
565        }
566        Ok(inputs)
567    }
568
569    pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
570        let inputs = self.prepare_inputs(node)?;
571        self.compute_one_with_inputs(node, inputs)
572    }
573
574    pub fn compute_one_with_inputs(
575        &mut self,
576        node: usize,
577        inputs: TVec<TValue>,
578    ) -> TractResult<()> {
579        let &mut SimpleState {
580            ref plan,
581            ref mut session_state,
582            ref mut values,
583            ref mut states,
584            ..
585        } = self;
586        let plan = plan.borrow();
587        let nodes = plan.model().nodes();
588        let node = &nodes[node];
589        let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)?;
590        values[node.id] = Some(vs);
591        Ok(())
592    }
593
594    pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
595        let values = {
596            #[allow(clippy::needless_collect)] // clippy bug ?
597            let precs: Vec<usize> =
598                self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
599            for i in precs.into_iter() {
600                if self.values[i].is_none() {
601                    let _ = self.compute_recursively(i)?;
602                }
603            }
604            let mut inputs: TVec<TValue> = tvec![];
605            {
606                let node = &self.model().nodes()[node];
607                for i in &node.inputs {
608                    inputs.push(self.values[i.node].as_ref().unwrap()[i.slot].clone())
609                }
610            }
611            let &mut Self { ref mut states, ref mut session_state, ref plan, .. } = self;
612            eval(
613                session_state,
614                states[node].as_deref_mut(),
615                &plan.borrow().model().nodes[node],
616                inputs,
617            )?
618        };
619        self.values[node] = Some(values);
620        Ok(self.values[node].as_ref().unwrap())
621    }
622
623    pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
624        let id = self.model().node_by_name(name)?.id;
625        Self::take(self, id)
626    }
627
628    pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
629        Ok(self.values[id]
630            .take()
631            .ok_or_else(|| format_err!("Node is not computed"))?
632            .into_iter()
633            .map(|v| v.into_tensor())
634            .collect())
635    }
636
637    pub fn plan(&self) -> &SimplePlan<F, O, M> {
638        self.plan.borrow()
639    }
640
641    pub fn model(&self) -> &Graph<F, O> {
642        self.plan().model()
643    }
644
645    pub fn freeze(&self) -> FrozenSimpleState<F, O, M, P> {
646        FrozenSimpleState {
647            plan: self.plan.clone(),
648            inputs: self
649                .session_state
650                .inputs
651                .iter()
652                .map(|(ix, t)| (*ix, t.clone().into_tensor()))
653                .collect(),
654            resolved_symbols: self.session_state.resolved_symbols.clone(),
655            scenario: self.session_state.scenario,
656            tensors: self.session_state.tensors.clone(),
657            states: self.states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
658            values: self
659                .values
660                .iter()
661                .enumerate()
662                .map(|(ix, t)| {
663                    if self.model().nodes[ix].op_is::<Const>() {
664                        t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
665                    } else {
666                        None
667                    }
668                })
669                .collect(),
670            _phantom: PhantomData,
671        }
672    }
673}
674
675pub fn eval<F, O>(
676    session_state: &mut SessionState,
677    mut state: Option<&mut (dyn OpState + 'static)>,
678    node: &Node<F, O>,
679    input: TVec<TValue>,
680) -> TractResult<TVec<TValue>>
681where
682    F: Fact + Clone + 'static,
683    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
684{
685    // eprint!("{node} {input:?}");
686    #[allow(clippy::let_and_return)]
687    let r = match state {
688        Some(ref mut state) => state.eval(session_state, node.op(), input),
689        None => node.op().eval_with_session(node.id, session_state, input),
690    }
691    .with_context(|| format!("Evaluating {node}"));
692    // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?);
693    r
694}
695
/// A snapshot of a [`SimpleState`] holding owned tensors and frozen op states (built by
/// `SimpleState::freeze`, turned back into a live state by `unfreeze`).
#[derive(Clone, Debug)]
pub struct FrozenSimpleState<F, O, M, P>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
    M: Borrow<Graph<F, O>>,
    P: Borrow<SimplePlan<F, O, M>> + Clone,
{
    /// The plan, owned or borrowed.
    plan: P,
    /// Session inputs keyed by input node id, as owned tensors.
    pub inputs: HashMap<usize, Tensor>,
    /// Copy of the session's resolved symbols.
    pub resolved_symbols: SymbolValues,
    /// Copy of the session's guessed scenario.
    pub scenario: Option<usize>,
    /// Copy of the session's named tensors.
    pub tensors: HashMap<String, Tensor>,
    /// Frozen op state per node id; `None` for stateless ops.
    pub states: Vec<Option<Box<dyn FrozenOpState>>>,
    /// Per-node values; only `Const` nodes carry `Some` when produced by `freeze`.
    pub values: Vec<Option<TVec<Tensor>>>,
    _phantom: PhantomData<(M, F, O)>,
}
713
714impl<F, O, M, P> FrozenSimpleState<F, O, M, P>
715where
716    F: Fact + Clone + 'static,
717    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
718    M: Borrow<Graph<F, O>>,
719    P: Borrow<SimplePlan<F, O, M>> + Clone,
720{
721    pub fn unfreeze(&self) -> SimpleState<F, O, M, P> {
722        let mut state = SimpleState {
723            plan: self.plan.clone(),
724            session_state: SessionState {
725                inputs: self.inputs.iter().map(|(ix, t)| (*ix, t.clone().into_tvalue())).collect(),
726                resolved_symbols: self.resolved_symbols.clone(),
727                scenario: self.scenario,
728                tensors: self.tensors.clone(),
729                cached_mmm_scratch_space: None.into(),
730                scratch_extensions: anymap3::Map::new(),
731            },
732            states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
733            values: self
734                .values
735                .iter()
736                .map(|t| t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect()))
737                .collect(),
738            _phantom: PhantomData,
739        };
740        state.populate_consts();
741        state
742    }
743}
744
#[cfg(test)]
mod test {
    use super::*;

    // Compile-time checks: a call only type-checks when `T` has the required auto trait.
    fn require_send<T: Send>() {}
    fn require_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        require_sync::<TypedModel>()
    }

    #[test]
    fn type_model_is_send() {
        require_send::<TypedModel>()
    }

    #[test]
    fn type_plan_is_send() {
        require_send::<TypedSimplePlan<TypedModel>>()
    }

    #[test]
    fn type_plan_is_sync() {
        require_sync::<TypedSimplePlan<TypedModel>>()
    }

    #[test]
    fn frozen_type_state_is_send() {
        require_send::<TypedFrozenSimpleState<TypedModel, TypedSimplePlan<TypedModel>>>()
    }
}