1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4
5use multithread::Executor;
6
7use crate::internal::*;
8use crate::model::{Fact, Graph, OutletId};
9use crate::ops::FrozenOpState;
10use crate::ops::konst::Const;
11use crate::runtime::RunOptions;
12
13use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
14
15pub struct TurnState {
16 pub resolved_symbols: SymbolValues,
17 pub scenario: Option<usize>,
18 pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
19 pub scratch_extensions: anymap3::Map,
20 pub values: Vec<Option<TVec<TValue>>>,
21}
22
23impl Default for TurnState {
24 fn default() -> Self {
25 TurnState {
26 resolved_symbols: SymbolValues::default(),
27 scenario: None,
28 cached_mmm_scratch_space: None.into(),
29 scratch_extensions: anymap3::Map::new(),
30 values: vec![],
31 }
32 }
33}
34
35impl Clone for TurnState {
36 fn clone(&self) -> Self {
37 TurnState {
38 resolved_symbols: self.resolved_symbols.clone(),
39 scenario: self.scenario,
40 cached_mmm_scratch_space: None.into(),
41 scratch_extensions: anymap3::Map::new(),
42 values: vec![],
43 }
44 }
45}
46
47pub trait SessionStateHandler: Send + Sync + Debug {
48 fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
49 fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
50}
51
52impl Debug for TurnState {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "SessionState({:?})", self.resolved_symbols)
55 }
56}
57
58#[derive(Debug, Clone)]
59pub struct SimplePlan<F, O>
60where
61 F: Fact + Clone + 'static,
62 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
63{
64 pub(crate) model: Arc<Graph<F, O>>,
65 outputs: Vec<OutletId>,
66 order: Vec<usize>,
67 flush_lists: Vec<TVec<usize>>,
68 has_unresolved_symbols: bool,
69 executor: Option<Executor>,
70 session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
71}
72
73impl<F, O> SimplePlan<F, O>
74where
75 F: Fact + Clone + 'static,
76 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
77{
78 pub fn new(model: impl Into<Arc<Graph<F, O>>>) -> TractResult<Arc<SimplePlan<F, O>>> {
80 let model = model.into();
81 Self::build(model, &RunOptions::default()).map(Arc::new)
82 }
83
84 pub fn new_with_options(
86 model: impl Into<Arc<Graph<F, O>>>,
87 options: &RunOptions,
88 ) -> TractResult<Arc<SimplePlan<F, O>>> {
89 let model = model.into();
90 Self::build(model, options).map(Arc::new)
91 }
92
93 #[deprecated]
95 pub fn new_for_output(
96 model: Graph<F, O>,
97 output: OutletId,
98 ) -> TractResult<Arc<SimplePlan<F, O>>> {
99 #[allow(deprecated)]
100 Self::build_with_outputs_and_deps(model, &[output], &[], &RunOptions::default())
101 .map(Arc::new)
102 }
103
104 #[deprecated]
106 pub fn new_for_outputs(
107 model: impl Into<Arc<Graph<F, O>>>,
108 outputs: &[OutletId],
109 ) -> TractResult<Arc<SimplePlan<F, O>>> {
110 #[allow(deprecated)]
111 Self::build_with_outputs_and_deps(model, outputs, &[], &RunOptions::default()).map(Arc::new)
112 }
113
114 pub fn with_session_handler<H: SessionStateHandler + 'static>(
115 mut self,
116 session_handler: H,
117 ) -> Self {
118 self.session_handler = Some(Arc::new(session_handler));
119 self
120 }
121
122 #[deprecated]
123 pub fn new_for_outputs_and_deps(
124 model: impl Into<Arc<Graph<F, O>>>,
125 outputs: &[OutletId],
126 deps: &[(usize, usize)],
127 ) -> TractResult<Arc<SimplePlan<F, O>>> {
128 #[allow(deprecated)]
129 Self::build_with_outputs_and_deps(model, outputs, deps, &RunOptions::default())
130 .map(Arc::new)
131 }
132
133 pub fn build(
134 model: impl Into<Arc<Graph<F, O>>>,
135 options: &RunOptions,
136 ) -> TractResult<SimplePlan<F, O>> {
137 let model = model.into();
138 let outputs = model.outputs.clone();
139 #[allow(deprecated)]
140 Self::build_with_outputs_and_deps(model, &outputs, &[], options)
141 }
142
143 #[deprecated]
144 pub fn build_with_outputs_and_deps(
145 model: impl Into<Arc<Graph<F, O>>>,
146 outputs: &[OutletId],
147 deps: &[(usize, usize)],
148 options: &RunOptions,
149 ) -> TractResult<SimplePlan<F, O>> {
150 let model = model.into();
151 let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
152 let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
153 let mut order = if options.skip_order_opt_ram {
154 eval_order_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
155 } else {
156 eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
157 };
158 order.retain(|node| !model.node(*node).op_is::<Const>());
159 let flush_lists = build_flush_list(&*model, &order, outputs, |n| !n.op_is::<Const>());
160
161 #[allow(clippy::mutable_key_type)]
162 let mut symbols: std::collections::HashSet<Symbol> = Default::default();
163 for node in &model.nodes {
164 for output in &node.outputs {
165 if let Ok(fact) = output.fact.to_typed_fact() {
166 symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
167 }
168 }
169 }
170 Ok(SimplePlan {
171 model,
172 order,
173 flush_lists,
174 outputs: outputs.to_vec(),
175 has_unresolved_symbols: !symbols.is_empty(),
176 executor: options.executor.clone(),
177 session_handler: None,
178 })
179 }
180
181 pub fn order_without_consts(&self) -> &[usize] {
182 &self.order
183 }
184
185 pub fn run(self: &Arc<Self>, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
186 let mut state = self.spawn()?;
187 state.run(inputs)
188 }
189
190 pub fn model(&self) -> &Graph<F, O> {
191 self.model.borrow()
192 }
193
194 pub fn spawn(self: &Arc<Self>) -> TractResult<SimpleState<F, O>> {
195 SimpleState::new(self)
196 }
197}
198
199#[derive(Clone, Debug)]
200pub struct SimpleState<F, O>
201where
202 F: Fact + Clone + 'static,
203 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
204{
205 pub(crate) plan: Arc<SimplePlan<F, O>>,
206 pub op_states: Vec<Option<Box<dyn OpState>>>,
207 pub turn_state: TurnState,
208}
209
210impl<F, O> SimpleState<F, O>
211where
212 F: Fact + Clone + 'static,
213 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
214{
215 pub fn new(plan: &Arc<SimplePlan<F, O>>) -> TractResult<SimpleState<F, O>> {
216 let plan = Arc::clone(plan);
217 let turn = TurnState::default();
218 let model = plan.model();
219 let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
220 let mut state = SimpleState { plan, op_states: states, turn_state: turn };
221 state.reset_op_states()?;
222 Ok(state)
223 }
224
225 pub fn new_from_inputs(
226 plan: &Arc<SimplePlan<F, O>>,
227 inputs: TVec<TValue>,
228 ) -> TractResult<SimpleState<F, O>> {
229 let mut state = SimpleState::new(plan)?;
230 state.set_inputs(inputs)?;
231 state.resolve_symbols_with_states()?;
232
233 Ok(state)
234 }
235
236 fn ready_turn(&mut self) {
237 if self.turn_state.values.len() == 0 {
238 self.turn_state.values = vec![None; self.plan.model.nodes().len()];
239 for node in &self.plan.model.nodes {
240 if let Some(k) = node.op_as::<Const>() {
241 self.turn_state.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
242 }
243 }
244 }
245 }
246 pub fn reset_turn(&mut self) -> TractResult<()> {
248 for node in &self.plan.order {
249 self.turn_state.values[*node] = None;
250 }
251 self.turn_state.resolved_symbols = SymbolValues::default();
252 Ok(())
253 }
254
255 fn reset_op_states(&mut self) -> TractResult<()> {
257 let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
258 for (ix, n) in plan.model.nodes.iter().enumerate() {
259 states[ix] = if n.op().is_stateless() { None } else { n.op().state(turn_state, ix)? };
260 }
261 Ok(())
262 }
263
264 fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
265 for state in self
266 .op_states
267 .iter_mut()
268 .filter_map(Option::as_mut)
269 .filter(|s| s.init_tensor_fact().is_some())
270 {
271 state.resolve_symbols(&mut self.turn_state)?;
272 }
273 Ok(())
274 }
275
276 pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
277 self.run_plan_with_eval(inputs, self::eval)
278 }
279
280 pub fn exec(&mut self) -> TractResult<()> {
281 self.exec_plan_with_eval(self::eval)
282 }
283
284 pub fn run_plan_with_eval<Eval, E>(
285 &mut self,
286 inputs: TVec<TValue>,
287 eval: Eval,
288 ) -> TractResult<TVec<TValue>>
289 where
290 Eval: for<'a, 'b, 'c> FnMut(
291 &'a mut TurnState,
292 Option<&'b mut (dyn OpState + 'static)>,
293 &'c Node<F, O>,
294 TVec<TValue>,
295 ) -> Result<TVec<TValue>, E>,
296 E: Into<anyhow::Error> + Send + Sync + 'static,
297 {
298 self.set_inputs(inputs)?;
299 self.resolve_symbols_with_states()?;
300 self.exec_plan_with_eval(eval)?;
301 let outputs = self.outputs()?;
302 self.reset_turn()?;
303 Ok(outputs)
304 }
305
306 pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
307 where
308 Eval: for<'a, 'b, 'c> FnMut(
309 &'a mut TurnState,
310 Option<&'b mut (dyn OpState + 'static)>,
311 &'c Node<F, O>,
312 TVec<TValue>,
313 ) -> Result<TVec<TValue>, E>,
314 E: Into<anyhow::Error> + Send + Sync + 'static,
315 {
316 if let Some(executor) = self.plan().executor.as_ref() {
317 tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
318 self.do_exec_plan_with_eval(eval)
319 })
320 } else {
321 self.do_exec_plan_with_eval(eval)
322 }
323 }
324
325 fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
326 where
327 Eval: for<'a, 'b, 'c> FnMut(
328 &'a mut TurnState,
329 Option<&'b mut (dyn OpState + 'static)>,
330 &'c Node<F, O>,
331 TVec<TValue>,
332 ) -> Result<TVec<TValue>, E>,
333 E: Into<anyhow::Error> + Send + Sync + 'static,
334 {
335 {
336 self.ready_turn();
337 self.plan
338 .session_handler
339 .as_ref()
340 .map(|it| it.before_plan_eval(&mut self.turn_state))
341 .transpose()?;
342
343 for (step, n) in self.plan.order.iter().enumerate() {
344 let node = self.plan.model.node(*n);
345 trace!("Running step {step}, node {node}");
346 let mut inputs: TVec<TValue> = tvec![];
347 for i in &node.inputs {
348 trace!(" use input {i:?}");
349 let prec_node = self.plan.model.node(i.node);
350 let prec = self.turn_state.values[i.node].as_ref().ok_or_else(|| {
351 format_err!("Computing {}, precursor {} not done:", node, prec_node)
352 })?;
353 inputs.push(prec[i.slot].clone())
354 }
355
356 for flush in &self.plan.flush_lists[step] {
357 trace!(" Ran {} can now flush {}", node, self.plan.model.node(*flush));
358 self.turn_state.values[*flush] = None;
359 }
360
361 if cfg!(debug_assertions) {
362 let facts = self.plan.model.node_input_facts(node.id)?;
363 if facts.len() != inputs.len() {
364 bail!(
365 "Evaluating {}: expected {} inputs, got {}",
366 node,
367 facts.len(),
368 inputs.len()
369 );
370 }
371 for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
372 if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
373 bail!(
374 "Evaluating {}: input {:?}, expected {:?}, got {:?}",
375 node,
376 ix,
377 f,
378 v
379 );
380 }
381 }
382 }
383
384 let vs = eval(
385 &mut self.turn_state,
386 self.op_states[node.id].as_deref_mut(),
387 node,
388 inputs,
389 )
390 .map_err(|e| e.into())?;
391
392 if self.plan.has_unresolved_symbols {
393 for (o, v) in node.outputs.iter().zip(vs.iter()) {
394 if let Ok(f) = o.fact.to_typed_fact() {
395 for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
396 Self::resolve(
397 &mut self.turn_state,
398 dim_abstract,
399 *dim_concrete as i64,
400 )?;
401 }
402 }
403 }
404 }
405 if cfg!(debug_assertions) {
406 let facts = self.plan.model.node_output_facts(node.id)?;
407 if facts.len() != vs.len() {
408 bail!(
409 "Evaluating {}: expected {} outputs, got {}",
410 node,
411 facts.len(),
412 vs.len()
413 );
414 }
415 for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
416 if node.outputs[ix].successors.len() == 0 {
417 continue;
418 }
419 if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
420 bail!(
421 "Evaluating {}: output {:?}, expected {:?}, got {:?}",
422 node,
423 ix,
424 f,
425 v
426 );
427 }
428 }
429 }
430
431 self.turn_state.values[node.id] = Some(vs);
432 }
433 self.plan
434 .session_handler
435 .as_ref()
436 .map(|it| it.after_plan_eval(&mut self.turn_state))
437 .transpose()?;
438 }
439 Ok(())
440 }
441
442 pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
443 ensure!(
444 inputs.len() == self.model().inputs.len(),
445 "Wrong number of inputs for model. Expected {} got {}",
446 self.model().inputs.len(),
447 inputs.len()
448 );
449
450 for (ix, t) in inputs.into_iter().enumerate() {
451 self.set_input(ix, t)?
452 }
453 Ok(())
454 }
455
456 fn resolve(state: &mut TurnState, expression: &TDim, provided: i64) -> TractResult<()> {
457 let expected = expression.eval(&state.resolved_symbols);
458 if let Ok(x) = expected.to_i64()
459 && x != provided
460 {
461 bail!("Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})")
462 }
463 if expected.symbols().len() == 1 {
464 let sym = expected.symbols().into_iter().next().unwrap();
465 if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
466 debug!("Determined symbol {sym}={v}");
467 state.resolved_symbols.set(&sym, v.to_i64().unwrap());
468 }
469 if state.scenario.is_none() {
470 let scope = sym
471 .scope()
472 .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
473 state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
474 }
475 }
476 Ok(())
477 }
478
479 pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
480 let outlet: OutletId = *self
481 .model()
482 .input_outlets()?
483 .get(input)
484 .with_context(|| format!("Invalid input id for model ({input})."))?;
485 if let Ok(fact) = self.plan.model.outlet_fact(outlet)?.to_typed_fact() {
486 for (expected, provided) in fact.shape.iter().zip(t.shape()) {
487 Self::resolve(&mut self.turn_state, expected, *provided as i64)?;
488 }
489 }
490 let fact = self.plan.model.outlet_fact(outlet)?;
491 ensure!(
492 fact.matches(&t, Some(&self.turn_state.resolved_symbols))
493 .with_context(|| format!("Setting input {input}"))?,
494 "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
495 );
496 self.ready_turn();
497 self.turn_state.values[outlet.node] = Some(tvec!(t));
498 Ok(())
499 }
500
501 pub fn output(&self, id: usize) -> TractResult<&TValue> {
502 let outlet = self.model().output_outlets()?.get(id).with_context(|| {
503 format!(
504 "Required output {}, only have {}",
505 id,
506 self.model().output_outlets().unwrap().len()
507 )
508 })?;
509 let value: &TValue = self
510 .turn_state
511 .values
512 .get(outlet.node)
513 .context("node id for output beyond node values array")?
514 .as_ref()
515 .context("node is not an output")?
516 .get(outlet.slot)
517 .context("slot id too high")?;
518 Ok(value)
519 }
520
521 pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
522 let &mut SimpleState { ref plan, ref mut turn_state, .. } = self;
523 let mut v = tvec![];
524 for o in plan.outputs.iter() {
525 let vs = turn_state.values[o.node].as_mut().ok_or_else(|| {
526 format_err!("Outputs of {:?} are not computed", &plan.model.nodes()[o.node])
527 })?;
528 v.push(vs[o.slot].clone())
529 }
530 Ok(v)
531 }
532
533 pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
534 self.turn_state.values[id] = Some(values);
535 Ok(())
536 }
537
538 pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
539 self.set_values(id, tvec!(value))
540 }
541
542 pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
543 let SimpleState { plan, turn_state, .. } = self;
544 let nodes = plan.model.nodes();
545 let node = &nodes[node];
546 let mut inputs: TVec<TValue> = tvec![];
547 for i in &node.inputs {
548 let prec_node = &nodes[i.node];
549 let prec = turn_state.values[i.node].as_ref().ok_or_else(|| {
550 format_err!("Computing {}, precursor {} not done.", node, prec_node)
551 })?;
552 inputs.push(prec[i.slot].clone())
553 }
554 Ok(inputs)
555 }
556
557 pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
558 let inputs = self.prepare_inputs(node)?;
559 self.compute_one_with_inputs(node, inputs)
560 }
561
562 pub fn compute_one_with_inputs(
563 &mut self,
564 node: usize,
565 inputs: TVec<TValue>,
566 ) -> TractResult<()> {
567 let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
568 let nodes = plan.model.nodes();
569 let node = &nodes[node];
570 let vs = eval(turn_state, states[node.id].as_deref_mut(), node, inputs)?;
571 turn_state.values[node.id] = Some(vs);
572 Ok(())
573 }
574
575 pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
576 let values = {
577 #[allow(clippy::needless_collect)] let precs: Vec<usize> =
579 self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
580 for i in precs.into_iter() {
581 if self.turn_state.values[i].is_none() {
582 let _ = self.compute_recursively(i)?;
583 }
584 }
585 let mut inputs: TVec<TValue> = tvec![];
586 {
587 let node = &self.model().nodes()[node];
588 for i in &node.inputs {
589 inputs.push(self.turn_state.values[i.node].as_ref().unwrap()[i.slot].clone())
590 }
591 }
592 let &mut Self {
593 op_states: ref mut states,
594 turn_state: ref mut session_state,
595 ref plan,
596 ..
597 } = self;
598 eval(session_state, states[node].as_deref_mut(), &plan.model().nodes[node], inputs)?
599 };
600 self.turn_state.values[node] = Some(values);
601 Ok(self.turn_state.values[node].as_ref().unwrap())
602 }
603
604 pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
605 let id = self.model().node_by_name(name)?.id;
606 Self::take(self, id)
607 }
608
609 pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
610 Ok(self.turn_state.values[id]
611 .take()
612 .ok_or_else(|| format_err!("Node is not computed"))?
613 .into_iter()
614 .map(|v| v.into_tensor())
615 .collect())
616 }
617
618 pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
619 &self.plan
620 }
621
622 pub fn model(&self) -> &Graph<F, O> {
623 &self.plan.model
624 }
625
626 pub fn freeze(&self) -> FrozenSimpleState<F, O> {
627 FrozenSimpleState {
628 plan: self.plan.clone(),
629 resolved_symbols: self.turn_state.resolved_symbols.clone(),
630 scenario: self.turn_state.scenario,
631 states: self.op_states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
632 values: self
633 .turn_state
634 .values
635 .iter()
636 .enumerate()
637 .map(|(ix, t)| {
638 if self.model().nodes[ix].op_is::<Const>() {
639 t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
640 } else {
641 None
642 }
643 })
644 .collect(),
645 }
646 }
647
648 pub fn freeze_into(self) -> FrozenSimpleState<F, O> {
649 let plan = self.plan;
650 let model = &plan.model;
651 FrozenSimpleState {
652 resolved_symbols: self.turn_state.resolved_symbols,
653 scenario: self.turn_state.scenario,
654 states: self.op_states.into_iter().map(|s| s.map(|s| s.freeze_into())).collect(),
655 values: self
656 .turn_state
657 .values
658 .into_iter()
659 .enumerate()
660 .map(|(ix, t)| {
661 if model.nodes[ix].op_is::<Const>() {
662 t.map(|t| t.into_iter().map(|t| t.into_tensor()).collect())
663 } else {
664 None
665 }
666 })
667 .collect(),
668 plan,
669 }
670 }
671}
672
673pub fn eval<F, O>(
674 session_state: &mut TurnState,
675 mut state: Option<&mut (dyn OpState + 'static)>,
676 node: &Node<F, O>,
677 input: TVec<TValue>,
678) -> TractResult<TVec<TValue>>
679where
680 F: Fact + Clone + 'static,
681 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
682{
683 #[allow(clippy::let_and_return)]
685 let r = match state {
686 Some(ref mut state) => state.eval(session_state, node.op(), input),
687 None => node.op().eval_with_session(node.id, session_state, input),
688 }
689 .with_context(|| format!("Evaluating {node}"));
690 r
692}
693
694#[derive(Clone, Debug)]
695pub struct FrozenSimpleState<F, O>
696where
697 F: Fact + Clone + 'static,
698 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
699{
700 plan: Arc<SimplePlan<F, O>>,
701 pub resolved_symbols: SymbolValues,
702 pub scenario: Option<usize>,
703 pub states: Vec<Option<Box<dyn FrozenOpState>>>,
704 pub values: Vec<Option<TVec<Tensor>>>,
705}
706
707impl<F, O> FrozenSimpleState<F, O>
708where
709 F: Fact + Clone + 'static,
710 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
711{
712 pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
713 &self.plan
714 }
715
716 pub fn unfreeze(&self) -> SimpleState<F, O> {
717 SimpleState {
718 plan: self.plan.clone(),
719 turn_state: TurnState {
720 resolved_symbols: self.resolved_symbols.clone(),
721 scenario: self.scenario,
722 cached_mmm_scratch_space: None.into(),
723 scratch_extensions: anymap3::Map::new(),
724 values: self
725 .values
726 .iter()
727 .map(|t| {
728 t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect())
729 })
730 .collect(),
731 },
732 op_states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
733 }
734 }
735}
736
#[cfg(test)]
mod test {
    use super::*;

    /// Compiles only when `T: Send`.
    fn assert_send<T: Send>() {}
    /// Compiles only when `T: Sync`.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>();
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>();
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan>();
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan>();
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState>();
    }
}