1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4
5use multithread::Executor;
6use tract_data::itertools::Itertools;
7
8use crate::internal::*;
9use crate::model::{Fact, Graph, OutletId};
10use crate::ops::FrozenOpState;
11use crate::ops::konst::Const;
12use crate::runtime::RunOptions;
13
14use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
15
/// Mutable, per-turn evaluation context shared by every op evaluated during a plan run.
pub struct TurnState {
    /// Values of the model symbols resolved so far (from input shapes, output shapes, and
    /// stateful ops).
    pub resolved_symbols: SymbolValues,
    /// Scenario index guessed from the resolved symbols by the symbol scope, if any.
    pub scenario: Option<usize>,
    /// Cached matrix-multiplication scratch space; presumably reused across matmul ops within a
    /// turn — not used in this file, TODO confirm.
    pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
    /// Type-keyed storage for extra per-turn data; presumably filled by ops or session handlers —
    /// not used in this file.
    pub scratch_extensions: anymap3::Map,
    /// Output values indexed by node id; `None` when a node has not been computed yet or its
    /// value has been flushed. Empty until the table is allocated.
    pub values: Vec<Option<TVec<TValue>>>,
}
23
24impl Default for TurnState {
25 fn default() -> Self {
26 TurnState {
27 resolved_symbols: SymbolValues::default(),
28 scenario: None,
29 cached_mmm_scratch_space: None.into(),
30 scratch_extensions: anymap3::Map::new(),
31 values: vec![],
32 }
33 }
34}
35
36impl Clone for TurnState {
37 fn clone(&self) -> Self {
38 TurnState {
39 resolved_symbols: self.resolved_symbols.clone(),
40 scenario: self.scenario,
41 cached_mmm_scratch_space: None.into(),
42 scratch_extensions: anymap3::Map::new(),
43 values: vec![],
44 }
45 }
46}
47
/// Hooks invoked around a full plan evaluation, with mutable access to the turn state.
pub trait SessionStateHandler: Send + Sync + Debug {
    /// Called once before the first node of the plan is evaluated.
    fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
    /// Called once after the last node of the plan has been evaluated successfully.
    fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
}
52
53impl Debug for TurnState {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 write!(f, "SessionState({:?})", self.resolved_symbols)
56 }
57}
58
/// An immutable, precomputed evaluation schedule for a model graph.
#[derive(Debug, Clone)]
pub struct SimplePlan<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The planned model.
    pub(crate) model: Arc<Graph<F, O>>,
    /// Outlets whose values are returned after a run.
    outputs: Vec<OutletId>,
    /// Node ids in evaluation order, with `Const` nodes removed.
    order: Vec<usize>,
    /// For each step of `order`, the node ids whose values can be dropped at that step.
    flush_lists: Vec<TVec<usize>>,
    /// True when some typed output fact of the model has a shape containing symbols.
    has_unresolved_symbols: bool,
    /// Optional executor; when set, evaluation runs inside a multithread scope using it.
    executor: Option<Executor>,
    /// Optional hooks called before and after each plan evaluation.
    session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
}
73
74impl<F, O> SimplePlan<F, O>
75where
76 F: Fact + Clone + 'static,
77 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
78{
79 pub fn new(model: impl Into<Arc<Graph<F, O>>>) -> TractResult<Arc<SimplePlan<F, O>>> {
81 let model = model.into();
82 Self::build(model, &RunOptions::default()).map(Arc::new)
83 }
84
85 pub fn new_with_options(
87 model: impl Into<Arc<Graph<F, O>>>,
88 options: &RunOptions,
89 ) -> TractResult<Arc<SimplePlan<F, O>>> {
90 let model = model.into();
91 Self::build(model, options).map(Arc::new)
92 }
93
94 #[deprecated]
96 pub fn new_for_output(
97 model: Graph<F, O>,
98 output: OutletId,
99 ) -> TractResult<Arc<SimplePlan<F, O>>> {
100 #[allow(deprecated)]
101 Self::build_with_outputs_and_deps(model, &[output], &[], &RunOptions::default())
102 .map(Arc::new)
103 }
104
105 #[deprecated]
107 pub fn new_for_outputs(
108 model: impl Into<Arc<Graph<F, O>>>,
109 outputs: &[OutletId],
110 ) -> TractResult<Arc<SimplePlan<F, O>>> {
111 #[allow(deprecated)]
112 Self::build_with_outputs_and_deps(model, outputs, &[], &RunOptions::default()).map(Arc::new)
113 }
114
115 pub fn with_session_handler<H: SessionStateHandler + 'static>(
116 mut self,
117 session_handler: H,
118 ) -> Self {
119 self.session_handler = Some(Arc::new(session_handler));
120 self
121 }
122
123 #[deprecated]
124 pub fn new_for_outputs_and_deps(
125 model: impl Into<Arc<Graph<F, O>>>,
126 outputs: &[OutletId],
127 deps: &[(usize, usize)],
128 ) -> TractResult<Arc<SimplePlan<F, O>>> {
129 #[allow(deprecated)]
130 Self::build_with_outputs_and_deps(model, outputs, deps, &RunOptions::default())
131 .map(Arc::new)
132 }
133
134 pub fn build(
135 model: impl Into<Arc<Graph<F, O>>>,
136 options: &RunOptions,
137 ) -> TractResult<SimplePlan<F, O>> {
138 let model = model.into();
139 let outputs = model.outputs.clone();
140 #[allow(deprecated)]
141 Self::build_with_outputs_and_deps(model, &outputs, &[], options)
142 }
143
144 #[deprecated]
145 pub fn build_with_outputs_and_deps(
146 model: impl Into<Arc<Graph<F, O>>>,
147 outputs: &[OutletId],
148 deps: &[(usize, usize)],
149 options: &RunOptions,
150 ) -> TractResult<SimplePlan<F, O>> {
151 let model = model.into();
152 let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
153 let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
154 let mut order = if options.skip_order_opt_ram {
155 eval_order_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
156 } else {
157 eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
158 };
159 order.retain(|node| !model.node(*node).op_is::<Const>());
160 let flush_lists = build_flush_list(&*model, &order, outputs, |n| !n.op_is::<Const>());
161
162 #[allow(clippy::mutable_key_type)]
163 let mut symbols: std::collections::HashSet<Symbol> = Default::default();
164 for node in &model.nodes {
165 for output in &node.outputs {
166 if let Ok(fact) = output.fact.to_typed_fact() {
167 symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
168 }
169 }
170 }
171 Ok(SimplePlan {
172 model,
173 order,
174 flush_lists,
175 outputs: outputs.to_vec(),
176 has_unresolved_symbols: !symbols.is_empty(),
177 executor: options.executor.clone(),
178 session_handler: None,
179 })
180 }
181
182 pub fn order_without_consts(&self) -> &[usize] {
183 &self.order
184 }
185
186 pub fn run(self: &Arc<Self>, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
187 let mut state = self.spawn()?;
188 state.run(inputs)
189 }
190
191 pub fn model(&self) -> &Graph<F, O> {
192 self.model.borrow()
193 }
194
195 pub fn spawn(self: &Arc<Self>) -> TractResult<SimpleState<F, O>> {
196 SimpleState::new(self)
197 }
198}
199
/// A mutable execution state for a [`SimplePlan`]: per-node op states plus the current turn.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The plan being executed.
    pub(crate) plan: Arc<SimplePlan<F, O>>,
    /// One entry per model node; `None` for stateless ops.
    pub op_states: Vec<Option<Box<dyn OpState>>>,
    /// Values and symbol resolution of the current turn.
    pub turn_state: TurnState,
}
210
211impl<F, O> SimpleState<F, O>
212where
213 F: Fact + Clone + 'static,
214 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
215{
216 pub fn new(plan: &Arc<SimplePlan<F, O>>) -> TractResult<SimpleState<F, O>> {
217 let plan = Arc::clone(plan);
218 let turn = TurnState::default();
219 let model = plan.model();
220 let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
221 let mut state = SimpleState { plan, op_states: states, turn_state: turn };
222 state.reset_op_states()?;
223 Ok(state)
224 }
225
226 pub fn new_from_inputs(
227 plan: &Arc<SimplePlan<F, O>>,
228 inputs: TVec<TValue>,
229 ) -> TractResult<SimpleState<F, O>> {
230 let mut state = SimpleState::new(plan)?;
231 state.set_inputs(inputs)?;
232 state.resolve_symbols_with_states()?;
233
234 Ok(state)
235 }
236
237 fn ready_turn(&mut self) {
238 if self.turn_state.values.len() == 0 {
239 self.turn_state.values = vec![None; self.plan.model.nodes().len()];
240 for node in &self.plan.model.nodes {
241 if let Some(k) = node.op_as::<Const>() {
242 self.turn_state.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
243 }
244 }
245 }
246 }
247 pub fn reset_turn(&mut self) -> TractResult<()> {
249 for node in &self.plan.order {
250 self.turn_state.values[*node] = None;
251 }
252 self.turn_state.resolved_symbols = SymbolValues::default();
253 Ok(())
254 }
255
256 fn reset_op_states(&mut self) -> TractResult<()> {
258 let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
259 for (ix, n) in plan.model.nodes.iter().enumerate() {
260 states[ix] = if n.op().is_stateless() { None } else { n.op().state(turn_state, ix)? };
261 }
262 Ok(())
263 }
264
265 pub fn init_states(&mut self, state_init_tensors: &[TValue]) -> TractResult<()> {
266 let states_to_init = self
267 .op_states
268 .iter_mut()
269 .filter_map(Option::as_mut)
270 .filter(|s| s.init_tensor_fact().is_some())
271 .collect_vec();
272 ensure!(
273 states_to_init.len() == state_init_tensors.len(),
274 "There are {} op to init but got {} tensors",
275 states_to_init.len(),
276 state_init_tensors.len()
277 );
278 let mut iterator = state_init_tensors.iter().cloned();
279 for state in states_to_init {
280 state.load_from(&mut self.turn_state, &mut iterator)?;
281 }
282 Ok(())
283 }
284
285 fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
286 for state in self
287 .op_states
288 .iter_mut()
289 .filter_map(Option::as_mut)
290 .filter(|s| s.init_tensor_fact().is_some())
291 {
292 state.resolve_symbols(&mut self.turn_state)?;
293 }
294 Ok(())
295 }
296
297 pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
298 self.run_plan_with_eval(inputs, self::eval)
299 }
300
301 pub fn exec(&mut self) -> TractResult<()> {
302 self.exec_plan_with_eval(self::eval)
303 }
304
305 pub fn run_plan_with_eval<Eval, E>(
306 &mut self,
307 inputs: TVec<TValue>,
308 eval: Eval,
309 ) -> TractResult<TVec<TValue>>
310 where
311 Eval: for<'a, 'b, 'c> FnMut(
312 &'a mut TurnState,
313 Option<&'b mut (dyn OpState + 'static)>,
314 &'c Node<F, O>,
315 TVec<TValue>,
316 ) -> Result<TVec<TValue>, E>,
317 E: Into<anyhow::Error> + Send + Sync + 'static,
318 {
319 self.set_inputs(inputs)?;
320 self.resolve_symbols_with_states()?;
321 self.exec_plan_with_eval(eval)?;
322 let outputs = self.outputs()?;
323 self.reset_turn()?;
324 Ok(outputs)
325 }
326
327 pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
328 where
329 Eval: for<'a, 'b, 'c> FnMut(
330 &'a mut TurnState,
331 Option<&'b mut (dyn OpState + 'static)>,
332 &'c Node<F, O>,
333 TVec<TValue>,
334 ) -> Result<TVec<TValue>, E>,
335 E: Into<anyhow::Error> + Send + Sync + 'static,
336 {
337 if let Some(executor) = self.plan().executor.as_ref() {
338 tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
339 self.do_exec_plan_with_eval(eval)
340 })
341 } else {
342 self.do_exec_plan_with_eval(eval)
343 }
344 }
345
346 fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
347 where
348 Eval: for<'a, 'b, 'c> FnMut(
349 &'a mut TurnState,
350 Option<&'b mut (dyn OpState + 'static)>,
351 &'c Node<F, O>,
352 TVec<TValue>,
353 ) -> Result<TVec<TValue>, E>,
354 E: Into<anyhow::Error> + Send + Sync + 'static,
355 {
356 {
357 self.ready_turn();
358 self.plan
359 .session_handler
360 .as_ref()
361 .map(|it| it.before_plan_eval(&mut self.turn_state))
362 .transpose()?;
363
364 for (step, n) in self.plan.order.iter().enumerate() {
365 let node = self.plan.model.node(*n);
366 trace!("Running step {step}, node {node}");
367 let mut inputs: TVec<TValue> = tvec![];
368 for i in &node.inputs {
369 trace!(" use input {i:?}");
370 let prec_node = self.plan.model.node(i.node);
371 let prec = self.turn_state.values[i.node].as_ref().ok_or_else(|| {
372 format_err!("Computing {}, precursor {} not done:", node, prec_node)
373 })?;
374 inputs.push(prec[i.slot].clone())
375 }
376
377 for flush in &self.plan.flush_lists[step] {
378 trace!(" Ran {} can now flush {}", node, self.plan.model.node(*flush));
379 self.turn_state.values[*flush] = None;
380 }
381
382 if cfg!(debug_assertions) {
383 let facts = self.plan.model.node_input_facts(node.id)?;
384 if facts.len() != inputs.len() {
385 bail!(
386 "Evaluating {}: expected {} inputs, got {}",
387 node,
388 facts.len(),
389 inputs.len()
390 );
391 }
392 for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
393 if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
394 bail!(
395 "Evaluating {}: input {:?}, expected {:?}, got {:?}",
396 node,
397 ix,
398 f,
399 v
400 );
401 }
402 }
403 }
404
405 let vs = eval(
406 &mut self.turn_state,
407 self.op_states[node.id].as_deref_mut(),
408 node,
409 inputs,
410 )
411 .map_err(|e| e.into())?;
412
413 if self.plan.has_unresolved_symbols {
414 for (o, v) in node.outputs.iter().zip(vs.iter()) {
415 if let Ok(f) = o.fact.to_typed_fact() {
416 for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
417 Self::resolve(
418 &mut self.turn_state,
419 dim_abstract,
420 *dim_concrete as i64,
421 )?;
422 }
423 }
424 }
425 }
426 if cfg!(debug_assertions) {
427 let facts = self.plan.model.node_output_facts(node.id)?;
428 if facts.len() != vs.len() {
429 bail!(
430 "Evaluating {}: expected {} outputs, got {}",
431 node,
432 facts.len(),
433 vs.len()
434 );
435 }
436 for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
437 if node.outputs[ix].successors.len() == 0 {
438 continue;
439 }
440 if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
441 bail!(
442 "Evaluating {}: output {:?}, expected {:?}, got {:?}",
443 node,
444 ix,
445 f,
446 v
447 );
448 }
449 }
450 }
451
452 self.turn_state.values[node.id] = Some(vs);
453 }
454 self.plan
455 .session_handler
456 .as_ref()
457 .map(|it| it.after_plan_eval(&mut self.turn_state))
458 .transpose()?;
459 }
460 Ok(())
461 }
462
463 pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
464 ensure!(
465 inputs.len() == self.model().inputs.len(),
466 "Wrong number of inputs for model. Expected {} got {}",
467 self.model().inputs.len(),
468 inputs.len()
469 );
470
471 for (ix, t) in inputs.into_iter().enumerate() {
472 self.set_input(ix, t)?
473 }
474 Ok(())
475 }
476
477 fn resolve(state: &mut TurnState, expression: &TDim, provided: i64) -> TractResult<()> {
478 let expected = expression.eval(&state.resolved_symbols);
479 if let Ok(x) = expected.to_i64() {
480 if x != provided {
481 bail!(
482 "Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})"
483 )
484 }
485 }
486 if expected.symbols().len() == 1 {
487 let sym = expected.symbols().into_iter().next().unwrap();
488 if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
489 debug!("Determined symbol {sym}={v}");
490 state.resolved_symbols.set(&sym, v.to_i64().unwrap());
491 }
492 if state.scenario.is_none() {
493 let scope = sym
494 .scope()
495 .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
496 state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
497 }
498 }
499 Ok(())
500 }
501
502 pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
503 let outlet: OutletId = *self
504 .model()
505 .input_outlets()?
506 .get(input)
507 .with_context(|| format!("Invalid input id for model ({input})."))?;
508 if let Ok(fact) = self.plan.model.outlet_fact(outlet)?.to_typed_fact() {
509 for (expected, provided) in fact.shape.iter().zip(t.shape()) {
510 Self::resolve(&mut self.turn_state, expected, *provided as i64)?;
511 }
512 }
513 let fact = self.plan.model.outlet_fact(outlet)?;
514 ensure!(
515 fact.matches(&t, Some(&self.turn_state.resolved_symbols))
516 .with_context(|| format!("Setting input {input}"))?,
517 "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
518 );
519 self.ready_turn();
520 self.turn_state.values[outlet.node] = Some(tvec!(t));
521 Ok(())
522 }
523
524 pub fn output(&self, id: usize) -> TractResult<&TValue> {
525 let outlet = self.model().output_outlets()?.get(id).with_context(|| {
526 format!(
527 "Required output {}, only have {}",
528 id,
529 self.model().output_outlets().unwrap().len()
530 )
531 })?;
532 let value: &TValue = self
533 .turn_state
534 .values
535 .get(outlet.node)
536 .context("node id for output beyond node values array")?
537 .as_ref()
538 .context("node is not an output")?
539 .get(outlet.slot)
540 .context("slot id too high")?;
541 Ok(value)
542 }
543
544 pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
545 let &mut SimpleState { ref plan, ref mut turn_state, .. } = self;
546 let mut v = tvec![];
547 for o in plan.outputs.iter() {
548 let vs = turn_state.values[o.node].as_mut().ok_or_else(|| {
549 format_err!("Outputs of {:?} are not computed", &plan.model.nodes()[o.node])
550 })?;
551 v.push(vs[o.slot].clone())
552 }
553 Ok(v)
554 }
555
556 pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
557 self.turn_state.values[id] = Some(values);
558 Ok(())
559 }
560
561 pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
562 self.set_values(id, tvec!(value))
563 }
564
565 pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
566 let SimpleState { plan, turn_state, .. } = self;
567 let nodes = plan.model.nodes();
568 let node = &nodes[node];
569 let mut inputs: TVec<TValue> = tvec![];
570 for i in &node.inputs {
571 let prec_node = &nodes[i.node];
572 let prec = turn_state.values[i.node].as_ref().ok_or_else(|| {
573 format_err!("Computing {}, precursor {} not done.", node, prec_node)
574 })?;
575 inputs.push(prec[i.slot].clone())
576 }
577 Ok(inputs)
578 }
579
580 pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
581 let inputs = self.prepare_inputs(node)?;
582 self.compute_one_with_inputs(node, inputs)
583 }
584
585 pub fn compute_one_with_inputs(
586 &mut self,
587 node: usize,
588 inputs: TVec<TValue>,
589 ) -> TractResult<()> {
590 let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
591 let nodes = plan.model.nodes();
592 let node = &nodes[node];
593 let vs = eval(turn_state, states[node.id].as_deref_mut(), node, inputs)?;
594 turn_state.values[node.id] = Some(vs);
595 Ok(())
596 }
597
598 pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
599 let values = {
600 #[allow(clippy::needless_collect)] let precs: Vec<usize> =
602 self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
603 for i in precs.into_iter() {
604 if self.turn_state.values[i].is_none() {
605 let _ = self.compute_recursively(i)?;
606 }
607 }
608 let mut inputs: TVec<TValue> = tvec![];
609 {
610 let node = &self.model().nodes()[node];
611 for i in &node.inputs {
612 inputs.push(self.turn_state.values[i.node].as_ref().unwrap()[i.slot].clone())
613 }
614 }
615 let &mut Self {
616 op_states: ref mut states,
617 turn_state: ref mut session_state,
618 ref plan,
619 ..
620 } = self;
621 eval(session_state, states[node].as_deref_mut(), &plan.model().nodes[node], inputs)?
622 };
623 self.turn_state.values[node] = Some(values);
624 Ok(self.turn_state.values[node].as_ref().unwrap())
625 }
626
627 pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
628 let id = self.model().node_by_name(name)?.id;
629 Self::take(self, id)
630 }
631
632 pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
633 Ok(self.turn_state.values[id]
634 .take()
635 .ok_or_else(|| format_err!("Node is not computed"))?
636 .into_iter()
637 .map(|v| v.into_tensor())
638 .collect())
639 }
640
641 pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
642 &self.plan
643 }
644
645 pub fn model(&self) -> &Graph<F, O> {
646 &self.plan.model
647 }
648
649 pub fn freeze(&self) -> FrozenSimpleState<F, O> {
650 FrozenSimpleState {
651 plan: self.plan.clone(),
652 resolved_symbols: self.turn_state.resolved_symbols.clone(),
653 scenario: self.turn_state.scenario,
654 states: self.op_states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
655 values: self
656 .turn_state
657 .values
658 .iter()
659 .enumerate()
660 .map(|(ix, t)| {
661 if self.model().nodes[ix].op_is::<Const>() {
662 t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
663 } else {
664 None
665 }
666 })
667 .collect(),
668 }
669 }
670}
671
672pub fn eval<F, O>(
673 session_state: &mut TurnState,
674 mut state: Option<&mut (dyn OpState + 'static)>,
675 node: &Node<F, O>,
676 input: TVec<TValue>,
677) -> TractResult<TVec<TValue>>
678where
679 F: Fact + Clone + 'static,
680 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
681{
682 #[allow(clippy::let_and_return)]
684 let r = match state {
685 Some(ref mut state) => state.eval(session_state, node.op(), input),
686 None => node.op().eval_with_session(node.id, session_state, input),
687 }
688 .with_context(|| format!("Evaluating {node}"));
689 r
691}
692
/// A detached snapshot of a [`SimpleState`] that can later be turned back into a live state.
#[derive(Clone, Debug)]
pub struct FrozenSimpleState<F, O>
where
    F: Fact + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// The plan the snapshot belongs to.
    plan: Arc<SimplePlan<F, O>>,
    /// Symbol values resolved at snapshot time.
    pub resolved_symbols: SymbolValues,
    /// Scenario selected at snapshot time, if any.
    pub scenario: Option<usize>,
    /// Frozen op state per node; `None` for stateless ops.
    pub states: Vec<Option<Box<dyn FrozenOpState>>>,
    /// Node values as plain tensors, indexed by node id (as produced by
    /// `SimpleState::freeze`, only `Const` node values are kept).
    pub values: Vec<Option<TVec<Tensor>>>,
}
705
706impl<F, O> FrozenSimpleState<F, O>
707where
708 F: Fact + Clone + 'static,
709 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
710{
711 pub fn unfreeze(&self) -> SimpleState<F, O> {
712 SimpleState {
713 plan: self.plan.clone(),
714 turn_state: TurnState {
715 resolved_symbols: self.resolved_symbols.clone(),
716 scenario: self.scenario,
717 cached_mmm_scratch_space: None.into(),
718 scratch_extensions: anymap3::Map::new(),
719 values: self
720 .values
721 .iter()
722 .map(|t| {
723 t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect())
724 })
725 .collect(),
726 },
727 op_states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
728 }
729 }
730}
731
#[cfg(test)]
mod test {
    use super::*;

    /// Compiles only when `T: Send`.
    fn assert_send<T: Send>() {}

    /// Compiles only when `T: Sync`.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>();
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>();
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan>();
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan>();
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState>();
    }
}