1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4use std::marker::PhantomData;
5
6use multithread::Executor;
7use tract_data::itertools::Itertools;
8
9use crate::internal::*;
10use crate::model::{Fact, Graph, OutletId};
11use crate::ops::konst::Const;
12use crate::ops::FrozenOpState;
13
14use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
15
16#[derive(Clone, Debug, Default)]
17pub struct PlanOptions {
18 pub skip_order_opt_ram: bool,
20
21 pub executor: Option<Executor>,
23}
24
25pub struct SessionState {
26 pub inputs: HashMap<usize, TValue>,
27 pub resolved_symbols: SymbolValues,
28 pub scenario: Option<usize>,
29 pub tensors: HashMap<String, Tensor>,
30 pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
31 pub scratch_extensions: anymap3::Map,
32}
33
34impl Default for SessionState {
35 fn default() -> Self {
36 SessionState {
37 inputs: HashMap::default(),
38 resolved_symbols: SymbolValues::default(),
39 tensors: HashMap::default(),
40 scenario: None,
41 cached_mmm_scratch_space: None.into(),
42 scratch_extensions: anymap3::Map::new(),
43 }
44 }
45}
46
47impl Clone for SessionState {
48 fn clone(&self) -> Self {
49 SessionState {
50 inputs: self.inputs.clone(),
51 resolved_symbols: self.resolved_symbols.clone(),
52 tensors: self.tensors.clone(),
53 scenario: self.scenario,
54 cached_mmm_scratch_space: None.into(),
55 scratch_extensions: anymap3::Map::new(),
56 }
57 }
58}
59
60pub trait SessionStateHandler: Send + Sync + Debug {
61 fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()>;
62 fn after_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()>;
63}
64
65impl Debug for SessionState {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 write!(f, "SessionState({:?})", self.resolved_symbols)
68 }
69}
70
71#[derive(Debug, Clone)]
72pub struct SimplePlan<F, O, M>
73where
74 F: Fact + Clone + 'static,
75 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
76 M: Borrow<Graph<F, O>>,
77{
78 model: M,
79 outputs: Vec<OutletId>,
80 order: Vec<usize>,
81 flush_lists: Vec<TVec<usize>>,
82 has_unresolved_symbols: bool,
83 executor: Option<Executor>,
84 session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
85 _casper: PhantomData<(F, O)>,
86}
87
88impl<F, O, M> SimplePlan<F, O, M>
89where
90 F: Fact + Clone + 'static,
91 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
92 M: Borrow<Graph<F, O>>,
93{
94 pub fn new(model: M) -> TractResult<SimplePlan<F, O, M>> {
96 let outputs = model.borrow().output_outlets()?.to_vec();
97 Self::build(model, &outputs, &[], &PlanOptions::default())
98 }
99
100 pub fn new_with_options(model: M, options: &PlanOptions) -> TractResult<SimplePlan<F, O, M>> {
102 let outputs = model.borrow().output_outlets()?.to_vec();
103 Self::build(model, &outputs, &[], options)
104 }
105
106 #[deprecated]
108 pub fn new_for_output(model: M, output: OutletId) -> TractResult<SimplePlan<F, O, M>> {
109 Self::build(model, &[output], &[], &PlanOptions::default())
110 }
111
112 #[deprecated]
114 pub fn new_for_outputs(model: M, outputs: &[OutletId]) -> TractResult<SimplePlan<F, O, M>> {
115 Self::build(model, outputs, &[], &PlanOptions::default())
116 }
117
118 pub fn with_session_handler<H: SessionStateHandler + 'static>(
119 mut self,
120 session_handler: H,
121 ) -> Self {
122 self.session_handler = Some(Arc::new(session_handler));
123 self
124 }
125
126 #[deprecated]
127 pub fn new_for_outputs_and_deps(
128 model: M,
129 outputs: &[OutletId],
130 deps: &[(usize, usize)],
131 ) -> TractResult<SimplePlan<F, O, M>> {
132 Self::build(model, outputs, deps, &PlanOptions::default())
133 }
134
135 pub fn build(
136 model: M,
137 outputs: &[OutletId],
138 deps: &[(usize, usize)],
139 options: &PlanOptions,
140 ) -> TractResult<SimplePlan<F, O, M>> {
141 let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
142 let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
143 let mut order = if options.skip_order_opt_ram {
144 eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?
145 } else {
146 eval_order_opt_ram_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?
147 };
148 order.retain(|node| !model.borrow().node(*node).op_is::<Const>());
149 let flush_lists =
150 build_flush_list(model.borrow(), &order, outputs, |n| !n.op_is::<Const>());
151
152 #[allow(clippy::mutable_key_type)]
153 let mut symbols: std::collections::HashSet<Symbol> = Default::default();
154 for node in &model.borrow().nodes {
155 for output in &node.outputs {
156 if let Ok(fact) = output.fact.to_typed_fact() {
157 symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
158 }
159 }
160 }
161 Ok(SimplePlan {
162 model,
163 order,
164 flush_lists,
165 outputs: outputs.to_vec(),
166 has_unresolved_symbols: !symbols.is_empty(),
167 _casper: PhantomData,
168 executor: options.executor.clone(),
169 session_handler: None,
170 })
171 }
172
173 pub fn order_without_consts(&self) -> &[usize] {
174 &self.order
175 }
176
177 pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
178 let mut state = SimpleState::new(self)?;
179 state.run(inputs)
180 }
181
182 pub fn model(&self) -> &Graph<F, O> {
183 self.model.borrow()
184 }
185}
186
187#[derive(Clone, Debug)]
188pub struct SimpleState<F, O, M, P>
189where
190 F: Fact + Clone + 'static,
191 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
192 M: Borrow<Graph<F, O>>,
193 P: Borrow<SimplePlan<F, O, M>>,
194{
195 plan: P,
196 pub states: Vec<Option<Box<dyn OpState>>>,
197 pub session_state: SessionState,
198 pub values: Vec<Option<TVec<TValue>>>,
199 _phantom: PhantomData<(M, F, O)>,
200}
201
202impl<F, O, M, P> SimpleState<F, O, M, P>
203where
204 F: Fact + Clone + 'static,
205 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
206 M: Borrow<Graph<F, O>>,
207 P: Borrow<SimplePlan<F, O, M>> + Clone,
208{
209 pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
210 let values = vec![None; plan.borrow().model.borrow().nodes().len()];
211 let session = SessionState::default();
212 let model = plan.borrow().model();
213 let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
214 let mut state =
215 SimpleState { plan, states, session_state: session, values, _phantom: PhantomData };
216 state.populate_consts();
217 state.reset_op_states()?;
218 Ok(state)
219 }
220
221 pub fn new_from_inputs(plan: P, inputs: TVec<TValue>) -> TractResult<SimpleState<F, O, M, P>> {
222 let mut state = SimpleState::new(plan)?;
223 state.set_inputs(inputs)?;
224 state.resolve_symbols_with_states()?;
225
226 Ok(state)
227 }
228
229 fn populate_consts(&mut self) {
230 for node in &self.plan.borrow().model().nodes {
231 if let Some(k) = node.op_as::<Const>() {
232 self.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
233 }
234 }
235 }
236
237 pub fn reset_turn(&mut self) -> TractResult<()> {
239 for node in &self.plan.borrow().order {
240 self.values[*node] = None;
241 }
242 self.session_state.resolved_symbols = SymbolValues::default();
243 Ok(())
244 }
245
246 pub fn reset_op_states(&mut self) -> TractResult<()> {
248 let &mut SimpleState { ref plan, ref mut session_state, ref mut states, .. } = self;
249 for (ix, n) in plan.borrow().model().nodes().iter().enumerate() {
250 states[ix] =
251 if n.op().is_stateless() { None } else { n.op().state(session_state, ix)? };
252 }
253 Ok(())
254 }
255
256 pub fn init_states(&mut self, state_init_tensors: &mut Vec<TValue>) -> TractResult<()> {
257 let states_to_init = self
258 .states
259 .iter_mut()
260 .filter_map(Option::as_mut)
261 .filter(|s| s.init_tensor_fact().is_some())
262 .collect_vec();
263 ensure!(
264 states_to_init.len() == state_init_tensors.len(),
265 "There are {} op to init but got {} tensors",
266 states_to_init.len(),
267 state_init_tensors.len()
268 );
269 for state in states_to_init {
270 state.load_from(&mut self.session_state, state_init_tensors)?;
271 }
272 Ok(())
273 }
274
275 fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
276 for state in self
277 .states
278 .iter_mut()
279 .filter_map(Option::as_mut)
280 .filter(|s| s.init_tensor_fact().is_some())
281 {
282 state.resolve_symbols(&mut self.session_state)?;
283 }
284 Ok(())
285 }
286
287 pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
288 self.run_plan_with_eval(inputs, self::eval)
289 }
290
291 pub fn exec(&mut self) -> TractResult<()> {
292 self.exec_plan_with_eval(self::eval)
293 }
294
295 pub fn run_plan_with_eval<Eval, E>(
296 &mut self,
297 inputs: TVec<TValue>,
298 eval: Eval,
299 ) -> TractResult<TVec<TValue>>
300 where
301 Eval: for<'a, 'b, 'c> FnMut(
302 &'a mut SessionState,
303 Option<&'b mut (dyn OpState + 'static)>,
304 &'c Node<F, O>,
305 TVec<TValue>,
306 ) -> Result<TVec<TValue>, E>,
307 E: Into<anyhow::Error> + Send + Sync + 'static,
308 {
309 self.set_inputs(inputs)?;
310 self.resolve_symbols_with_states()?;
311 self.exec_plan_with_eval(eval)?;
312 let outputs = self.outputs()?;
313 self.reset_turn()?;
314 Ok(outputs)
315 }
316
317 pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
318 where
319 Eval: for<'a, 'b, 'c> FnMut(
320 &'a mut SessionState,
321 Option<&'b mut (dyn OpState + 'static)>,
322 &'c Node<F, O>,
323 TVec<TValue>,
324 ) -> Result<TVec<TValue>, E>,
325 E: Into<anyhow::Error> + Send + Sync + 'static,
326 {
327 if let Some(executor) = self.plan().executor.as_ref() {
328 tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
329 self.do_exec_plan_with_eval(eval)
330 })
331 } else {
332 self.do_exec_plan_with_eval(eval)
333 }
334 }
335
336 fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
337 where
338 Eval: for<'a, 'b, 'c> FnMut(
339 &'a mut SessionState,
340 Option<&'b mut (dyn OpState + 'static)>,
341 &'c Node<F, O>,
342 TVec<TValue>,
343 ) -> Result<TVec<TValue>, E>,
344 E: Into<anyhow::Error> + Send + Sync + 'static,
345 {
346 {
347 let plan = self.plan.borrow();
348 let model = plan.model.borrow();
349 plan.session_handler
350 .as_ref()
351 .map(|it| it.before_plan_eval(&mut self.session_state))
352 .transpose()?;
353
354 for (step, n) in plan.order.iter().enumerate() {
355 let node = model.node(*n);
356 trace!("Running step {step}, node {node}");
357 let mut inputs: TVec<TValue> = tvec![];
358 for i in &node.inputs {
359 trace!(" use input {i:?}");
360 let prec_node = model.node(i.node);
361 let prec = self.values[i.node].as_ref().ok_or_else(|| {
362 format_err!("Computing {}, precursor {} not done:", node, prec_node)
363 })?;
364 inputs.push(prec[i.slot].clone())
365 }
366
367 for flush in &plan.flush_lists[step] {
368 trace!(" Ran {} can now flush {}", node, model.node(*flush));
369 self.values[*flush] = None;
370 }
371
372 if cfg!(debug_assertions) {
373 let facts = model.node_input_facts(node.id)?;
374 if facts.len() != inputs.len() {
375 bail!(
376 "Evaluating {}: expected {} inputs, got {}",
377 node,
378 facts.len(),
379 inputs.len()
380 );
381 }
382 for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
383 if !f.matches(v, Some(&self.session_state.resolved_symbols))? {
384 bail!(
385 "Evaluating {}: input {:?}, expected {:?}, got {:?}",
386 node,
387 ix,
388 f,
389 v
390 );
391 }
392 }
393 }
394
395 let vs = eval(
396 &mut self.session_state,
397 self.states[node.id].as_deref_mut(),
398 node,
399 inputs,
400 )
401 .map_err(|e| e.into())?;
402
403 if plan.has_unresolved_symbols {
404 for (o, v) in node.outputs.iter().zip(vs.iter()) {
405 if let Ok(f) = o.fact.to_typed_fact() {
406 for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
407 Self::resolve(
408 &mut self.session_state,
409 dim_abstract,
410 *dim_concrete as i64,
411 )?;
412 }
413 }
414 }
415 }
416 if cfg!(debug_assertions) {
417 let facts = model.node_output_facts(node.id)?;
418 if facts.len() != vs.len() {
419 bail!(
420 "Evaluating {}: expected {} outputs, got {}",
421 node,
422 facts.len(),
423 vs.len()
424 );
425 }
426 for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
427 if node.outputs[ix].successors.len() == 0 {
428 continue;
429 }
430 if !f.matches(v, Some(&self.session_state.resolved_symbols))? {
431 bail!(
432 "Evaluating {}: output {:?}, expected {:?}, got {:?}",
433 node,
434 ix,
435 f,
436 v
437 );
438 }
439 }
440 }
441
442 self.values[node.id] = Some(vs);
443 }
444 plan.session_handler
445 .as_ref()
446 .map(|it| it.after_plan_eval(&mut self.session_state))
447 .transpose()?;
448 }
449 Ok(())
450 }
451
452 pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
453 ensure!(
454 inputs.len() == self.model().inputs.len(),
455 "Wrong number of inputs for model. Expected {} got {}",
456 self.model().inputs.len(),
457 inputs.len()
458 );
459
460 for (ix, t) in inputs.into_iter().enumerate() {
461 self.set_input(ix, t)?
462 }
463 Ok(())
464 }
465
466 fn resolve(state: &mut SessionState, expression: &TDim, provided: i64) -> TractResult<()> {
467 let expected = expression.eval(&state.resolved_symbols);
468 if let Ok(x) = expected.to_i64() {
469 if x != provided {
470 bail!("Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})")
471 }
472 }
473 if expected.symbols().len() == 1 {
474 let sym = expected.symbols().into_iter().next().unwrap();
475 if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
476 debug!("Determined symbol {sym}={v}");
477 state.resolved_symbols.set(&sym, v.to_i64().unwrap());
478 }
479 if state.scenario.is_none() {
480 let scope = sym
481 .scope()
482 .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
483 state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
484 }
485 }
486 Ok(())
487 }
488
489 pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
490 let outlet: OutletId = *self
491 .model()
492 .input_outlets()?
493 .get(input)
494 .with_context(|| format!("Invalid input id for model ({input})."))?;
495 if let Ok(fact) = self.plan.borrow().model().outlet_fact(outlet)?.to_typed_fact() {
496 for (expected, provided) in fact.shape.iter().zip(t.shape()) {
497 Self::resolve(&mut self.session_state, expected, *provided as i64)?;
498 }
499 }
500 let fact = self.plan.borrow().model().outlet_fact(outlet)?;
501 ensure!(
502 fact.matches(&t, Some(&self.session_state.resolved_symbols))
503 .with_context(|| format!("Setting input {input}"))?,
504 "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
505 );
506 self.session_state.inputs.insert(outlet.node, t);
507 Ok(())
508 }
509
510 pub fn output(&self, id: usize) -> TractResult<&TValue> {
511 let outlet = self.model().output_outlets()?.get(id).with_context(|| {
512 format!(
513 "Required output {}, only have {}",
514 id,
515 self.model().output_outlets().unwrap().len()
516 )
517 })?;
518 let value: &TValue = self
519 .values
520 .get(outlet.node)
521 .context("node id for output beyond node values array")?
522 .as_ref()
523 .context("node is not an output")?
524 .get(outlet.slot)
525 .context("slot id too high")?;
526 Ok(value)
527 }
528
529 pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
530 let &mut SimpleState { ref plan, ref mut values, .. } = self;
531 let mut v = tvec![];
532 for o in plan.borrow().outputs.iter() {
533 let vs = values[o.node].as_mut().ok_or_else(|| {
534 format_err!(
535 "Outputs of {:?} are not computed",
536 &plan.borrow().model().nodes()[o.node]
537 )
538 })?;
539 v.push(vs[o.slot].clone())
540 }
541 Ok(v)
542 }
543
544 pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
545 self.values[id] = Some(values);
546 Ok(())
547 }
548
549 pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
550 self.set_values(id, tvec!(value))
551 }
552
553 pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
554 let SimpleState { plan, values, .. } = self;
555 let plan = plan.borrow();
556 let nodes = plan.model().nodes();
557 let node = &nodes[node];
558 let mut inputs: TVec<TValue> = tvec![];
559 for i in &node.inputs {
560 let prec_node = &nodes[i.node];
561 let prec = values[i.node].as_ref().ok_or_else(|| {
562 format_err!("Computing {}, precursor {} not done.", node, prec_node)
563 })?;
564 inputs.push(prec[i.slot].clone())
565 }
566 Ok(inputs)
567 }
568
569 pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
570 let inputs = self.prepare_inputs(node)?;
571 self.compute_one_with_inputs(node, inputs)
572 }
573
574 pub fn compute_one_with_inputs(
575 &mut self,
576 node: usize,
577 inputs: TVec<TValue>,
578 ) -> TractResult<()> {
579 let &mut SimpleState {
580 ref plan,
581 ref mut session_state,
582 ref mut values,
583 ref mut states,
584 ..
585 } = self;
586 let plan = plan.borrow();
587 let nodes = plan.model().nodes();
588 let node = &nodes[node];
589 let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)?;
590 values[node.id] = Some(vs);
591 Ok(())
592 }
593
594 pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
595 let values = {
596 #[allow(clippy::needless_collect)] let precs: Vec<usize> =
598 self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
599 for i in precs.into_iter() {
600 if self.values[i].is_none() {
601 let _ = self.compute_recursively(i)?;
602 }
603 }
604 let mut inputs: TVec<TValue> = tvec![];
605 {
606 let node = &self.model().nodes()[node];
607 for i in &node.inputs {
608 inputs.push(self.values[i.node].as_ref().unwrap()[i.slot].clone())
609 }
610 }
611 let &mut Self { ref mut states, ref mut session_state, ref plan, .. } = self;
612 eval(
613 session_state,
614 states[node].as_deref_mut(),
615 &plan.borrow().model().nodes[node],
616 inputs,
617 )?
618 };
619 self.values[node] = Some(values);
620 Ok(self.values[node].as_ref().unwrap())
621 }
622
623 pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
624 let id = self.model().node_by_name(name)?.id;
625 Self::take(self, id)
626 }
627
628 pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
629 Ok(self.values[id]
630 .take()
631 .ok_or_else(|| format_err!("Node is not computed"))?
632 .into_iter()
633 .map(|v| v.into_tensor())
634 .collect())
635 }
636
637 pub fn plan(&self) -> &SimplePlan<F, O, M> {
638 self.plan.borrow()
639 }
640
641 pub fn model(&self) -> &Graph<F, O> {
642 self.plan().model()
643 }
644
645 pub fn freeze(&self) -> FrozenSimpleState<F, O, M, P> {
646 FrozenSimpleState {
647 plan: self.plan.clone(),
648 inputs: self
649 .session_state
650 .inputs
651 .iter()
652 .map(|(ix, t)| (*ix, t.clone().into_tensor()))
653 .collect(),
654 resolved_symbols: self.session_state.resolved_symbols.clone(),
655 scenario: self.session_state.scenario,
656 tensors: self.session_state.tensors.clone(),
657 states: self.states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
658 values: self
659 .values
660 .iter()
661 .enumerate()
662 .map(|(ix, t)| {
663 if self.model().nodes[ix].op_is::<Const>() {
664 t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
665 } else {
666 None
667 }
668 })
669 .collect(),
670 _phantom: PhantomData,
671 }
672 }
673}
674
675pub fn eval<F, O>(
676 session_state: &mut SessionState,
677 mut state: Option<&mut (dyn OpState + 'static)>,
678 node: &Node<F, O>,
679 input: TVec<TValue>,
680) -> TractResult<TVec<TValue>>
681where
682 F: Fact + Clone + 'static,
683 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
684{
685 #[allow(clippy::let_and_return)]
687 let r = match state {
688 Some(ref mut state) => state.eval(session_state, node.op(), input),
689 None => node.op().eval_with_session(node.id, session_state, input),
690 }
691 .with_context(|| format!("Evaluating {node}"));
692 r
694}
695
696#[derive(Clone, Debug)]
697pub struct FrozenSimpleState<F, O, M, P>
698where
699 F: Fact + Clone + 'static,
700 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
701 M: Borrow<Graph<F, O>>,
702 P: Borrow<SimplePlan<F, O, M>> + Clone,
703{
704 plan: P,
705 pub inputs: HashMap<usize, Tensor>,
706 pub resolved_symbols: SymbolValues,
707 pub scenario: Option<usize>,
708 pub tensors: HashMap<String, Tensor>,
709 pub states: Vec<Option<Box<dyn FrozenOpState>>>,
710 pub values: Vec<Option<TVec<Tensor>>>,
711 _phantom: PhantomData<(M, F, O)>,
712}
713
714impl<F, O, M, P> FrozenSimpleState<F, O, M, P>
715where
716 F: Fact + Clone + 'static,
717 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
718 M: Borrow<Graph<F, O>>,
719 P: Borrow<SimplePlan<F, O, M>> + Clone,
720{
721 pub fn unfreeze(&self) -> SimpleState<F, O, M, P> {
722 let mut state = SimpleState {
723 plan: self.plan.clone(),
724 session_state: SessionState {
725 inputs: self.inputs.iter().map(|(ix, t)| (*ix, t.clone().into_tvalue())).collect(),
726 resolved_symbols: self.resolved_symbols.clone(),
727 scenario: self.scenario,
728 tensors: self.tensors.clone(),
729 cached_mmm_scratch_space: None.into(),
730 scratch_extensions: anymap3::Map::new(),
731 },
732 states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
733 values: self
734 .values
735 .iter()
736 .map(|t| t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect()))
737 .collect(),
738 _phantom: PhantomData,
739 };
740 state.populate_consts();
741 state
742 }
743}
744
#[cfg(test)]
mod test {
    //! Compile-time checks of the `Send` / `Sync` properties of the plan types.
    use super::*;

    /// Compiles only when `T` is `Send`.
    fn assert_send<T: Send>() {}

    /// Compiles only when `T` is `Sync`.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>();
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>();
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan<TypedModel>>();
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan<TypedModel>>();
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState<TypedModel, TypedSimplePlan<TypedModel>>>();
    }
}