use tract_core::internal::*;
use tract_core::{downcast_rs, dyn_clone};
/// Object-safe, type-erased view over a tract model graph.
///
/// Lets callers inspect and edit a model (nodes, outlets, names, nested sub-models)
/// without knowing the concrete fact and operator types of the underlying graph.
pub trait Model:
    downcast_rs::Downcast + std::fmt::Debug + dyn_clone::DynClone + Send + Sync
{
    /// Returns the id of the node named `name`, or an error if there is no such node.
    fn node_id_by_name(&self, name: &str) -> TractResult<usize>;
    /// Name of node `id`.
    fn node_name(&self, id: usize) -> &str;
    /// Operator of node `id`, viewed as a `dyn Op`.
    fn node_op(&self, id: usize) -> &dyn Op;
    /// Whether node `id` is a constant node.
    fn node_const(&self, id: usize) -> bool;
    /// Operator name of node `id`.
    fn node_op_name(&self, id: usize) -> Cow<'_, str>;
    /// Outlets feeding the inputs of node `id`.
    fn node_inputs(&self, id: usize) -> &[OutletId];
    /// Number of outputs of node `id`.
    fn node_output_count(&self, id: usize) -> usize;
    /// Total number of nodes in the model.
    fn nodes_len(&self) -> usize;
    /// `Display` rendering of node `id`.
    fn node_display(&self, id: usize) -> String;
    /// `Debug` rendering of node `id`.
    fn node_debug(&self, id: usize) -> String;
    /// Node ids in an order suitable for evaluating the whole model.
    fn eval_order(&self) -> TractResult<Vec<usize>>;
    /// Evaluation order for the part of the model connecting the given input and output
    /// node ids.
    fn eval_order_for_io(&self, inputs: &[usize], outputs: &[usize]) -> TractResult<Vec<usize>>;
    /// The model's input outlets.
    fn input_outlets(&self) -> &[OutletId];
    /// Selects the model inputs by node name.
    fn set_input_names(&mut self, names: &[&str]) -> TractResult<()>;
    /// Selects the model outputs by node name.
    fn set_output_names(&mut self, names: &[&str]) -> TractResult<()>;
    /// The model's output outlets.
    fn output_outlets(&self) -> &[OutletId];
    /// Fact of `outlet`, converted to a [`TypedFact`].
    fn outlet_typedfact(&self, outlet: OutletId) -> TractResult<TypedFact>;
    /// Human-readable rendering of the fact attached to `outlet`.
    fn outlet_fact_format(&self, outlet: OutletId) -> String;
    /// Optional label attached to outlet `id`.
    fn outlet_label(&self, id: OutletId) -> Option<&str>;
    /// Inlets consuming `outlet`.
    fn outlet_successors(&self, outlet: OutletId) -> &[InletId];

    /// If node `id` wraps a nested model, returns a short kind label together with the
    /// inner model; otherwise returns `None`.
    ///
    /// The label is `"submodel"` for `SubmodelOp`. It is `"loop"` for the scan variants
    /// (`LirScan`, `Scan`, and `InferenceScan` when the `hir` feature is enabled).
    fn nested_models(&self, id: usize) -> Option<(String, &dyn Model)> {
        // Fetch the op once; every probe below downcasts the same trait object.
        let op = self.node_op(id);
        if let Some(submodel) = op.downcast_ref::<tract_core::ops::submodel::SubmodelOp>() {
            return Some(("submodel".into(), submodel.model()));
        }
        if let Some(lir) = op.downcast_ref::<tract_core::ops::scan::LirScan>() {
            return Some(("loop".into(), lir.plan.model()));
        }
        if let Some(mir) = op.downcast_ref::<tract_core::ops::scan::Scan>() {
            return Some(("loop".into(), &mir.body));
        }
        #[cfg(feature = "hir")]
        if let Some(hir) = op.downcast_ref::<tract_hir::ops::scan::InferenceScan>() {
            return Some(("loop".into(), &hir.body));
        }
        None
    }

    /// Iteration count of the nested model of node `id` for the given input facts, as
    /// reported by the op itself.
    ///
    /// Returns `None` when the node is not a `SubmodelOp`, `LirScan` or `Scan`, or when the
    /// op cannot determine the count. Unlike [`Model::nested_models`], `InferenceScan` is
    /// not probed here.
    fn nested_models_iters(&self, id: usize, input: &[&TypedFact]) -> Option<TDim> {
        let op = self.node_op(id);
        if let Some(submodel) = op.downcast_ref::<tract_core::ops::submodel::SubmodelOp>() {
            submodel.iteration_count(input)
        } else if let Some(lir) = op.downcast_ref::<tract_core::ops::scan::LirScan>() {
            lir.iteration_count(input)
        } else if let Some(mir) = op.downcast_ref::<tract_core::ops::scan::Scan>() {
            mir.iteration_count(input)
        } else {
            None
        }
    }

    /// Sets the model outputs automatically.
    fn auto_outputs(&mut self) -> TractResult<()>;
    /// Model-level properties, keyed by name.
    fn properties(&self) -> &HashMap<String, Arc<Tensor>>;
    /// Symbol named `name` from the model's symbol table.
    fn get_or_intern_symbol(&self, name: &str) -> Symbol;
    /// Renames node `id` to `name`.
    fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()>;
}
// Generate downcasting helpers on `dyn Model` (backed by the `Downcast` supertrait) so a
// type-erased model can be turned back into its concrete graph type.
downcast_rs::impl_downcast!(Model);
// Make boxed `dyn Model` trait objects cloneable (backed by the `DynClone` supertrait).
dyn_clone::clone_trait_object!(Model);
/// [`Model`] implementation for any tract [`Graph`] with hashable facts and ops viewable as
/// `dyn Op`.
///
/// NOTE: several methods below (`set_input_names`, `set_output_names`, `outlet_label`,
/// `auto_outputs`, `rename_node`) call a method with the same name on `self`. These rely on
/// `Graph` providing inherent methods with those names. Inherent methods take precedence
/// over trait methods during method resolution, so the calls delegate rather than recurse.
impl<F, O> Model for Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: std::fmt::Debug
        + std::fmt::Display
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + 'static
        + Send
        + Sync,
    Graph<F, O>: Send + Sync + 'static,
{
    fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        // Linear scan; the first node whose name matches exactly wins.
        self.nodes
            .iter()
            .find(|n| n.name == name)
            .map(|n| n.id)
            .with_context(|| format!("No node found for name: \"{name}\""))
    }

    fn node_name(&self, id: usize) -> &str {
        &self.nodes[id].name
    }

    fn node_op_name(&self, id: usize) -> Cow<'_, str> {
        self.node(id).op().name()
    }

    fn node_const(&self, id: usize) -> bool {
        // Constants are identified by their operator name.
        self.node_op_name(id) == "Const"
    }

    fn node_inputs(&self, id: usize) -> &[OutletId] {
        &self.nodes[id].inputs
    }

    fn node_output_count(&self, id: usize) -> usize {
        self.nodes[id].outputs.len()
    }

    fn nodes_len(&self) -> usize {
        self.nodes.len()
    }

    fn node_display(&self, id: usize) -> String {
        format!("{}", self.nodes[id])
    }

    fn node_debug(&self, id: usize) -> String {
        format!("{:?}", self.nodes[id])
    }

    fn eval_order(&self) -> TractResult<Vec<usize>> {
        crate::model::eval_order(self)
    }

    fn eval_order_for_io(&self, inputs: &[usize], outputs: &[usize]) -> TractResult<Vec<usize>> {
        crate::model::order::eval_order_for_nodes(&self.nodes, inputs, outputs, &[])
    }

    fn input_outlets(&self) -> &[OutletId] {
        &self.inputs
    }

    fn set_input_names(&mut self, names: &[&str]) -> TractResult<()> {
        // Delegates to Graph's inherent method.
        self.set_input_names(names.iter())
    }

    fn set_output_names(&mut self, names: &[&str]) -> TractResult<()> {
        // Delegates to Graph's inherent method; forwarded the same way as the inputs.
        self.set_output_names(names.iter())
    }

    fn output_outlets(&self) -> &[OutletId] {
        &self.outputs
    }

    fn node_op(&self, id: usize) -> &dyn Op {
        self.nodes[id].op.as_ref()
    }

    fn outlet_typedfact(&self, outlet: OutletId) -> TractResult<TypedFact> {
        Ok(self.outlet_fact(outlet)?.to_typed_fact()?.into_owned())
    }

    fn outlet_fact_format(&self, outlet: OutletId) -> String {
        // This is a display helper: render a lookup failure into the output rather than
        // panicking.
        match self.outlet_fact(outlet) {
            Ok(fact) => format!("{fact:?}"),
            Err(e) => format!("<no fact for outlet {outlet:?}: {e}>"),
        }
    }

    fn outlet_label(&self, id: OutletId) -> Option<&str> {
        self.outlet_label(id)
    }

    fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        &self.nodes[outlet.node].outputs[outlet.slot].successors
    }

    fn auto_outputs(&mut self) -> TractResult<()> {
        self.auto_outputs()
    }

    fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
        &self.properties
    }

    fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        self.rename_node(id, name)
    }

    fn get_or_intern_symbol(&self, name: &str) -> Symbol {
        self.symbol_table.sym(name)
    }
}