1use std::borrow::Borrow;
2use std::cell::RefCell;
3use std::fmt::{Debug, Display};
4
5use multithread::Executor;
6
7use crate::internal::*;
8use crate::model::{Fact, Graph, OutletId};
9use crate::ops::FrozenOpState;
10use crate::ops::konst::Const;
11use crate::runtime::RunOptions;
12
13use self::order::{build_flush_list, eval_order_for_nodes, eval_order_opt_ram_for_nodes};
14
15pub struct TurnState {
16 pub resolved_symbols: SymbolValues,
17 pub scenario: Option<usize>,
18 pub cached_mmm_scratch_space: RefCell<Option<Box<dyn tract_linalg::mmm::ScratchSpace>>>,
19 pub scratch_extensions: anymap3::Map,
20 pub values: Vec<Option<TVec<TValue>>>,
21}
22
23impl Default for TurnState {
24 fn default() -> Self {
25 TurnState {
26 resolved_symbols: SymbolValues::default(),
27 scenario: None,
28 cached_mmm_scratch_space: None.into(),
29 scratch_extensions: anymap3::Map::new(),
30 values: vec![],
31 }
32 }
33}
34
35impl Clone for TurnState {
36 fn clone(&self) -> Self {
37 TurnState {
38 resolved_symbols: self.resolved_symbols.clone(),
39 scenario: self.scenario,
40 cached_mmm_scratch_space: None.into(),
41 scratch_extensions: anymap3::Map::new(),
42 values: vec![],
43 }
44 }
45}
46
47pub trait SessionStateHandler: Send + Sync + Debug {
48 fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
49 fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()>;
50}
51
52impl Debug for TurnState {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "SessionState({:?})", self.resolved_symbols)
55 }
56}
57
58#[derive(Debug, Clone)]
59pub struct SimplePlan<F, O>
60where
61 F: Fact + Clone + 'static,
62 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
63{
64 pub(crate) model: Arc<Graph<F, O>>,
65 outputs: Vec<OutletId>,
66 order: Vec<usize>,
67 flush_lists: Vec<TVec<usize>>,
68 has_unresolved_symbols: bool,
69 symbols: Vec<Symbol>,
70 executor: Option<Executor>,
71 session_handler: Option<Arc<dyn SessionStateHandler + 'static>>,
72}
73
74impl<F, O> SimplePlan<F, O>
75where
76 F: Fact + Clone + 'static,
77 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
78{
79 pub fn new(model: impl Into<Arc<Graph<F, O>>>) -> TractResult<Arc<SimplePlan<F, O>>> {
81 let model = model.into();
82 Self::build(model, &RunOptions::default()).map(Arc::new)
83 }
84
85 pub fn new_with_options(
87 model: impl Into<Arc<Graph<F, O>>>,
88 options: &RunOptions,
89 ) -> TractResult<Arc<SimplePlan<F, O>>> {
90 let model = model.into();
91 Self::build(model, options).map(Arc::new)
92 }
93
94 #[deprecated]
96 pub fn new_for_output(
97 model: Graph<F, O>,
98 output: OutletId,
99 ) -> TractResult<Arc<SimplePlan<F, O>>> {
100 #[allow(deprecated)]
101 Self::build_with_outputs_and_deps(model, &[output], &[], &RunOptions::default())
102 .map(Arc::new)
103 }
104
105 #[deprecated]
107 pub fn new_for_outputs(
108 model: impl Into<Arc<Graph<F, O>>>,
109 outputs: &[OutletId],
110 ) -> TractResult<Arc<SimplePlan<F, O>>> {
111 #[allow(deprecated)]
112 Self::build_with_outputs_and_deps(model, outputs, &[], &RunOptions::default()).map(Arc::new)
113 }
114
115 pub fn with_session_handler<H: SessionStateHandler + 'static>(
116 mut self,
117 session_handler: H,
118 ) -> Self {
119 self.session_handler = Some(Arc::new(session_handler));
120 self
121 }
122
123 #[deprecated]
124 pub fn new_for_outputs_and_deps(
125 model: impl Into<Arc<Graph<F, O>>>,
126 outputs: &[OutletId],
127 deps: &[(usize, usize)],
128 ) -> TractResult<Arc<SimplePlan<F, O>>> {
129 #[allow(deprecated)]
130 Self::build_with_outputs_and_deps(model, outputs, deps, &RunOptions::default())
131 .map(Arc::new)
132 }
133
134 pub fn build(
135 model: impl Into<Arc<Graph<F, O>>>,
136 options: &RunOptions,
137 ) -> TractResult<SimplePlan<F, O>> {
138 let model = model.into();
139 let outputs = model.outputs.clone();
140 #[allow(deprecated)]
141 Self::build_with_outputs_and_deps(model, &outputs, &[], options)
142 }
143
144 #[deprecated]
145 pub fn build_with_outputs_and_deps(
146 model: impl Into<Arc<Graph<F, O>>>,
147 outputs: &[OutletId],
148 deps: &[(usize, usize)],
149 options: &RunOptions,
150 ) -> TractResult<SimplePlan<F, O>> {
151 let model = model.into();
152 let inputs = model.input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
153 let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
154 let mut order = if options.skip_order_opt_ram {
155 eval_order_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
156 } else {
157 eval_order_opt_ram_for_nodes(model.nodes(), &inputs, &outputs_nodes, deps)?
158 };
159 order.retain(|node| !model.node(*node).op_is::<Const>());
160 let flush_lists = build_flush_list(&*model, &order, outputs, |n| !n.op_is::<Const>());
161
162 #[allow(clippy::mutable_key_type)]
163 let mut symbols: std::collections::HashSet<Symbol> = Default::default();
164 for node in &model.nodes {
165 for output in &node.outputs {
166 if let Ok(fact) = output.fact.to_typed_fact() {
167 symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
168 }
169 }
170 }
171 Ok(SimplePlan {
172 model,
173 order,
174 flush_lists,
175 outputs: outputs.to_vec(),
176 has_unresolved_symbols: !symbols.is_empty(),
177 symbols: symbols.into_iter().collect(),
178 executor: options.executor.clone(),
179 session_handler: None,
180 })
181 }
182
183 pub fn order_without_consts(&self) -> &[usize] {
184 &self.order
185 }
186
187 pub fn run(self: &Arc<Self>, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
188 let mut state = self.spawn()?;
189 state.run(inputs)
190 }
191
192 pub fn model(&self) -> &Graph<F, O> {
193 self.model.borrow()
194 }
195
196 pub fn spawn(self: &Arc<Self>) -> TractResult<SimpleState<F, O>> {
197 SimpleState::new(self)
198 }
199}
200
201#[derive(Clone, Debug)]
202pub struct SimpleState<F, O>
203where
204 F: Fact + Clone + 'static,
205 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
206{
207 pub(crate) plan: Arc<SimplePlan<F, O>>,
208 pub op_states: Vec<Option<Box<dyn OpState>>>,
209 pub turn_state: TurnState,
210}
211
212impl<F, O> SimpleState<F, O>
213where
214 F: Fact + Clone + 'static,
215 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
216{
217 pub fn new(plan: &Arc<SimplePlan<F, O>>) -> TractResult<SimpleState<F, O>> {
218 let plan = Arc::clone(plan);
219 let turn = TurnState::default();
220 let model = plan.model();
221 let states: Vec<Option<Box<dyn OpState>>> = vec![None; model.nodes.len()];
222 let mut state = SimpleState { plan, op_states: states, turn_state: turn };
223 state.reset_op_states()?;
224 Ok(state)
225 }
226
227 pub fn new_from_inputs(
228 plan: &Arc<SimplePlan<F, O>>,
229 inputs: TVec<TValue>,
230 ) -> TractResult<SimpleState<F, O>> {
231 let mut state = SimpleState::new(plan)?;
232 state.set_inputs(inputs)?;
233 state.resolve_symbols_with_states()?;
234
235 Ok(state)
236 }
237
238 fn ready_turn(&mut self) {
239 if self.turn_state.values.len() == 0 {
240 self.turn_state.values = vec![None; self.plan.model.nodes().len()];
241 for node in &self.plan.model.nodes {
242 if let Some(k) = node.op_as::<Const>() {
243 self.turn_state.values[node.id] = Some(tvec!(k.val().clone().into_tvalue()));
244 }
245 }
246 }
247 }
248 pub fn reset_turn(&mut self) -> TractResult<()> {
250 self.reset_turn_keep_symbols();
251 self.turn_state.resolved_symbols = SymbolValues::default();
252 Ok(())
253 }
254
255 pub(crate) fn reset_turn_keep_symbols(&mut self) {
260 for node in &self.plan.order {
261 self.turn_state.values[*node] = None;
262 }
263 }
264
265 pub(crate) fn clear_resolved_symbols(&mut self) {
269 self.turn_state.resolved_symbols = SymbolValues::default();
270 self.turn_state.scenario = None;
271 }
272
273 fn reset_op_states(&mut self) -> TractResult<()> {
275 let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
276 for (ix, n) in plan.model.nodes.iter().enumerate() {
277 states[ix] = if n.op().is_stateless() { None } else { n.op().state(turn_state, ix)? };
278 }
279 Ok(())
280 }
281
282 pub(crate) fn resolve_symbols_with_states(&mut self) -> TractResult<()> {
283 for state in self
284 .op_states
285 .iter_mut()
286 .filter_map(Option::as_mut)
287 .filter(|s| s.init_tensor_fact().is_some())
288 {
289 state.resolve_symbols(&mut self.turn_state)?;
290 }
291 Ok(())
292 }
293
294 pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
295 self.run_plan_with_eval(inputs, self::eval)
296 }
297
298 pub fn exec(&mut self) -> TractResult<()> {
299 self.exec_plan_with_eval(self::eval)
300 }
301
302 pub fn run_plan_with_eval<Eval, E>(
303 &mut self,
304 inputs: TVec<TValue>,
305 eval: Eval,
306 ) -> TractResult<TVec<TValue>>
307 where
308 Eval: for<'a, 'b, 'c> FnMut(
309 &'a mut TurnState,
310 Option<&'b mut (dyn OpState + 'static)>,
311 &'c Node<F, O>,
312 TVec<TValue>,
313 ) -> Result<TVec<TValue>, E>,
314 E: Into<anyhow::Error> + Send + Sync + 'static,
315 {
316 self.set_inputs(inputs)?;
317 self.resolve_symbols_with_states()?;
318 self.exec_plan_with_eval(eval)?;
319 let outputs = self.outputs()?;
320 self.reset_turn()?;
321 Ok(outputs)
322 }
323
324 pub fn exec_plan_with_eval<Eval, E>(&mut self, eval: Eval) -> TractResult<()>
325 where
326 Eval: for<'a, 'b, 'c> FnMut(
327 &'a mut TurnState,
328 Option<&'b mut (dyn OpState + 'static)>,
329 &'c Node<F, O>,
330 TVec<TValue>,
331 ) -> Result<TVec<TValue>, E>,
332 E: Into<anyhow::Error> + Send + Sync + 'static,
333 {
334 if let Some(executor) = self.plan().executor.as_ref() {
335 tract_linalg::multithread::multithread_tract_scope(executor.clone(), || {
336 self.do_exec_plan_with_eval(eval)
337 })
338 } else {
339 self.do_exec_plan_with_eval(eval)
340 }
341 }
342
343 fn do_exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
344 where
345 Eval: for<'a, 'b, 'c> FnMut(
346 &'a mut TurnState,
347 Option<&'b mut (dyn OpState + 'static)>,
348 &'c Node<F, O>,
349 TVec<TValue>,
350 ) -> Result<TVec<TValue>, E>,
351 E: Into<anyhow::Error> + Send + Sync + 'static,
352 {
353 {
354 self.ready_turn();
355 self.plan
356 .session_handler
357 .as_ref()
358 .map(|it| it.before_plan_eval(&mut self.turn_state))
359 .transpose()?;
360
361 let mut syms_done = !self.plan.has_unresolved_symbols
362 || self
363 .plan
364 .symbols
365 .iter()
366 .all(|s| self.turn_state.resolved_symbols.get(s).is_some());
367
368 for (step, n) in self.plan.order.iter().enumerate() {
369 let node = self.plan.model.node(*n);
370 trace!("Running step {step}, node {node}");
371 let mut inputs: TVec<TValue> = tvec![];
372 for i in &node.inputs {
373 trace!(" use input {i:?}");
374 let prec_node = self.plan.model.node(i.node);
375 let prec = self.turn_state.values[i.node].as_ref().ok_or_else(|| {
376 format_err!("Computing {}, precursor {} not done:", node, prec_node)
377 })?;
378 inputs.push(prec[i.slot].clone())
379 }
380
381 for flush in &self.plan.flush_lists[step] {
382 trace!(" Ran {} can now flush {}", node, self.plan.model.node(*flush));
383 self.turn_state.values[*flush] = None;
384 }
385
386 if cfg!(debug_assertions) {
387 let facts = self.plan.model.node_input_facts(node.id)?;
388 if facts.len() != inputs.len() {
389 bail!(
390 "Evaluating {}: expected {} inputs, got {}",
391 node,
392 facts.len(),
393 inputs.len()
394 );
395 }
396 for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
397 if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
398 bail!(
399 "Evaluating {}: input {:?}, expected {:?}, got {:?}",
400 node,
401 ix,
402 f,
403 v
404 );
405 }
406 }
407 }
408
409 let vs = eval(
410 &mut self.turn_state,
411 self.op_states[node.id].as_deref_mut(),
412 node,
413 inputs,
414 )
415 .map_err(|e| e.into())?;
416
417 if !syms_done && self.plan.has_unresolved_symbols {
418 for (o, v) in node.outputs.iter().zip(vs.iter()) {
419 if let Ok(f) = o.fact.to_typed_fact() {
420 for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
421 Self::resolve(
422 &mut self.turn_state,
423 dim_abstract,
424 *dim_concrete as i64,
425 )?;
426 }
427 }
428 }
429 if self
430 .plan
431 .symbols
432 .iter()
433 .all(|s| self.turn_state.resolved_symbols.get(s).is_some())
434 {
435 syms_done = true;
436 }
437 }
438 if cfg!(debug_assertions) {
439 let facts = self.plan.model.node_output_facts(node.id)?;
440 if facts.len() != vs.len() {
441 bail!(
442 "Evaluating {}: expected {} outputs, got {}",
443 node,
444 facts.len(),
445 vs.len()
446 );
447 }
448 for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
449 if node.outputs[ix].successors.len() == 0 {
450 continue;
451 }
452 if !f.matches(v, Some(&self.turn_state.resolved_symbols))? {
453 bail!(
454 "Evaluating {}: output {:?}, expected {:?}, got {:?}",
455 node,
456 ix,
457 f,
458 v
459 );
460 }
461 }
462 }
463
464 self.turn_state.values[node.id] = Some(vs);
465 }
466 self.plan
467 .session_handler
468 .as_ref()
469 .map(|it| it.after_plan_eval(&mut self.turn_state))
470 .transpose()?;
471 }
472 Ok(())
473 }
474
475 pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
476 ensure!(
477 inputs.len() == self.model().inputs.len(),
478 "Wrong number of inputs for model. Expected {} got {}",
479 self.model().inputs.len(),
480 inputs.len()
481 );
482
483 for (ix, t) in inputs.into_iter().enumerate() {
484 self.set_input(ix, t)?
485 }
486 Ok(())
487 }
488
489 pub(crate) fn set_inputs_drain(&mut self, inputs: &mut TVec<TValue>) -> TractResult<()> {
493 ensure!(
494 inputs.len() == self.model().inputs.len(),
495 "Wrong number of inputs for model. Expected {} got {}",
496 self.model().inputs.len(),
497 inputs.len()
498 );
499 for (ix, t) in inputs.drain(..).enumerate() {
500 self.set_input(ix, t)?
501 }
502 Ok(())
503 }
504
505 fn resolve(state: &mut TurnState, expression: &TDim, provided: i64) -> TractResult<()> {
506 if let TDim::Sym(sym) = expression
507 && state.resolved_symbols.get(sym).is_none()
508 {
509 state.resolved_symbols.set(sym, provided);
510 if state.scenario.is_none() {
511 let scope = sym.scope().with_context(|| {
512 format!(
513 "Symbol {sym:?} points to an invalid (dead ?) SymbolScope. \
514 Make sure to create symbols using the model-managed SymbolScope."
515 )
516 })?;
517 state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
518 }
519 return Ok(());
520 }
521 let expected = expression.eval(&state.resolved_symbols);
522 if let Some(x) = expected.as_i64()
523 && x != provided
524 {
525 bail!("Clashing resolution for expression. {expression}={x} != {provided}. ({state:?})")
526 }
527 if expected.symbols().len() == 1 {
528 let sym = expected.symbols().into_iter().next().unwrap();
529 if let Some(v) = solve_for(&sym, &expected, &provided.to_dim()) {
530 debug!("Determined symbol {sym}={v}");
531 state.resolved_symbols.set(&sym, v.to_i64().unwrap());
532 }
533 if state.scenario.is_none() {
534 let scope = sym
535 .scope()
536 .with_context(|| format!("Symbol {sym:?} points to an invalid (dead ?) SymbolScope. Make sure to create symbols using the model-managed SymbolScope."))?;
537 state.scenario = scope.guess_scenario(&state.resolved_symbols)?;
538 }
539 }
540 Ok(())
541 }
542
543 pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
544 let outlet: OutletId = *self
545 .model()
546 .input_outlets()?
547 .get(input)
548 .with_context(|| format!("Invalid input id for model ({input})."))?;
549 if let Ok(fact) = self.plan.model.outlet_fact(outlet)?.to_typed_fact() {
550 for (expected, provided) in fact.shape.iter().zip(t.shape()) {
551 Self::resolve(&mut self.turn_state, expected, *provided as i64)?;
552 }
553 }
554 let fact = self.plan.model.outlet_fact(outlet)?;
555 ensure!(
556 fact.matches(&t, Some(&self.turn_state.resolved_symbols))
557 .with_context(|| format!("Setting input {input}"))?,
558 "Input at index {input} has incorrect dtype or shape (got {t:?}, expected to match fact {fact:?})",
559 );
560 self.ready_turn();
561 self.turn_state.values[outlet.node] = Some(tvec!(t));
562 Ok(())
563 }
564
565 pub fn output(&self, id: usize) -> TractResult<&TValue> {
566 let outlet = self.model().output_outlets()?.get(id).with_context(|| {
567 format!(
568 "Required output {}, only have {}",
569 id,
570 self.model().output_outlets().unwrap().len()
571 )
572 })?;
573 let value: &TValue = self
574 .turn_state
575 .values
576 .get(outlet.node)
577 .context("node id for output beyond node values array")?
578 .as_ref()
579 .context("node is not an output")?
580 .get(outlet.slot)
581 .context("slot id too high")?;
582 Ok(value)
583 }
584
585 pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
586 let &mut SimpleState { ref plan, ref mut turn_state, .. } = self;
587 let mut v = tvec![];
588 for o in plan.outputs.iter() {
589 let vs = turn_state.values[o.node].as_mut().ok_or_else(|| {
590 format_err!("Outputs of {:?} are not computed", &plan.model.nodes()[o.node])
591 })?;
592 v.push(vs[o.slot].clone())
593 }
594 Ok(v)
595 }
596
597 pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
598 self.turn_state.values[id] = Some(values);
599 Ok(())
600 }
601
602 pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
603 self.set_values(id, tvec!(value))
604 }
605
606 pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
607 let SimpleState { plan, turn_state, .. } = self;
608 let nodes = plan.model.nodes();
609 let node = &nodes[node];
610 let mut inputs: TVec<TValue> = tvec![];
611 for i in &node.inputs {
612 let prec_node = &nodes[i.node];
613 let prec = turn_state.values[i.node].as_ref().ok_or_else(|| {
614 format_err!("Computing {}, precursor {} not done.", node, prec_node)
615 })?;
616 inputs.push(prec[i.slot].clone())
617 }
618 Ok(inputs)
619 }
620
621 pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
622 let inputs = self.prepare_inputs(node)?;
623 self.compute_one_with_inputs(node, inputs)
624 }
625
626 pub fn compute_one_with_inputs(
627 &mut self,
628 node: usize,
629 inputs: TVec<TValue>,
630 ) -> TractResult<()> {
631 let &mut SimpleState { ref plan, ref mut turn_state, op_states: ref mut states, .. } = self;
632 let nodes = plan.model.nodes();
633 let node = &nodes[node];
634 let vs = eval(turn_state, states[node.id].as_deref_mut(), node, inputs)?;
635 turn_state.values[node.id] = Some(vs);
636 Ok(())
637 }
638
639 pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
640 let values = {
641 #[allow(clippy::needless_collect)] let precs: Vec<usize> =
643 self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
644 for i in precs.into_iter() {
645 if self.turn_state.values[i].is_none() {
646 let _ = self.compute_recursively(i)?;
647 }
648 }
649 let mut inputs: TVec<TValue> = tvec![];
650 {
651 let node = &self.model().nodes()[node];
652 for i in &node.inputs {
653 inputs.push(self.turn_state.values[i.node].as_ref().unwrap()[i.slot].clone())
654 }
655 }
656 let &mut Self {
657 op_states: ref mut states,
658 turn_state: ref mut session_state,
659 ref plan,
660 ..
661 } = self;
662 eval(session_state, states[node].as_deref_mut(), &plan.model().nodes[node], inputs)?
663 };
664 self.turn_state.values[node] = Some(values);
665 Ok(self.turn_state.values[node].as_ref().unwrap())
666 }
667
668 pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
669 let id = self.model().node_by_name(name)?.id;
670 Self::take(self, id)
671 }
672
673 pub fn take(&mut self, id: usize) -> TractResult<TVec<Tensor>> {
674 Ok(self.turn_state.values[id]
675 .take()
676 .ok_or_else(|| format_err!("Node is not computed"))?
677 .into_iter()
678 .map(|v| v.into_tensor())
679 .collect())
680 }
681
682 pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
683 &self.plan
684 }
685
686 pub fn model(&self) -> &Graph<F, O> {
687 &self.plan.model
688 }
689
690 pub fn freeze(&self) -> FrozenSimpleState<F, O> {
691 FrozenSimpleState {
692 plan: self.plan.clone(),
693 resolved_symbols: self.turn_state.resolved_symbols.clone(),
694 scenario: self.turn_state.scenario,
695 states: self.op_states.iter().map(|s| s.as_ref().map(|s| s.freeze())).collect(),
696 values: self
697 .turn_state
698 .values
699 .iter()
700 .enumerate()
701 .map(|(ix, t)| {
702 if self.model().nodes[ix].op_is::<Const>() {
703 t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tensor()).collect())
704 } else {
705 None
706 }
707 })
708 .collect(),
709 }
710 }
711
712 pub fn freeze_into(self) -> FrozenSimpleState<F, O> {
713 let plan = self.plan;
714 let model = &plan.model;
715 FrozenSimpleState {
716 resolved_symbols: self.turn_state.resolved_symbols,
717 scenario: self.turn_state.scenario,
718 states: self.op_states.into_iter().map(|s| s.map(|s| s.freeze_into())).collect(),
719 values: self
720 .turn_state
721 .values
722 .into_iter()
723 .enumerate()
724 .map(|(ix, t)| {
725 if model.nodes[ix].op_is::<Const>() {
726 t.map(|t| t.into_iter().map(|t| t.into_tensor()).collect())
727 } else {
728 None
729 }
730 })
731 .collect(),
732 plan,
733 }
734 }
735}
736
737pub fn eval<F, O>(
738 session_state: &mut TurnState,
739 mut state: Option<&mut (dyn OpState + 'static)>,
740 node: &Node<F, O>,
741 input: TVec<TValue>,
742) -> TractResult<TVec<TValue>>
743where
744 F: Fact + Clone + 'static,
745 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
746{
747 #[allow(clippy::let_and_return)]
749 let r = match state {
750 Some(ref mut state) => state.eval(session_state, node.op(), input),
751 None => node.op().eval_with_session(node.id, session_state, input),
752 }
753 .with_context(|| format!("Evaluating {node}"));
754 r
756}
757
758#[derive(Clone, Debug)]
759pub struct FrozenSimpleState<F, O>
760where
761 F: Fact + Clone + 'static,
762 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
763{
764 plan: Arc<SimplePlan<F, O>>,
765 pub resolved_symbols: SymbolValues,
766 pub scenario: Option<usize>,
767 pub states: Vec<Option<Box<dyn FrozenOpState>>>,
768 pub values: Vec<Option<TVec<Tensor>>>,
769}
770
771impl<F, O> FrozenSimpleState<F, O>
772where
773 F: Fact + Clone + 'static,
774 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
775{
776 pub fn plan(&self) -> &Arc<SimplePlan<F, O>> {
777 &self.plan
778 }
779
780 pub fn unfreeze(&self) -> SimpleState<F, O> {
781 SimpleState {
782 plan: self.plan.clone(),
783 turn_state: TurnState {
784 resolved_symbols: self.resolved_symbols.clone(),
785 scenario: self.scenario,
786 cached_mmm_scratch_space: None.into(),
787 scratch_extensions: anymap3::Map::new(),
788 values: self
789 .values
790 .iter()
791 .map(|t| {
792 t.as_ref().map(|t| t.iter().map(|t| t.clone().into_tvalue()).collect())
793 })
794 .collect(),
795 },
796 op_states: self.states.iter().map(|s| s.as_ref().map(|s| s.unfreeze())).collect(),
797 }
798 }
799}
800
#[cfg(test)]
mod test {
    use super::*;

    /// Compiles only when `T` is `Send`.
    fn assert_send<T: Send>() {}

    /// Compiles only when `T` is `Sync`.
    fn assert_sync<T: Sync>() {}

    #[test]
    fn type_model_is_sync() {
        assert_sync::<TypedModel>();
    }

    #[test]
    fn type_model_is_send() {
        assert_send::<TypedModel>();
    }

    #[test]
    fn type_plan_is_send() {
        assert_send::<TypedSimplePlan>();
    }

    #[test]
    fn type_plan_is_sync() {
        assert_sync::<TypedSimplePlan>();
    }

    #[test]
    fn frozen_type_state_is_send() {
        assert_send::<TypedFrozenSimpleState>();
    }
}