1use tract_core::internal::*;
2use tract_core::{downcast_rs, dyn_clone};
3
4pub trait Model:
6 downcast_rs::Downcast + std::fmt::Debug + dyn_clone::DynClone + Send + Sync
7{
8 fn node_id_by_name(&self, name: &str) -> TractResult<usize>;
10
11 fn node_name(&self, id: usize) -> &str;
13
14 fn node_op(&self, id: usize) -> &dyn Op;
16
17 fn node_const(&self, id: usize) -> bool;
19
20 fn node_op_name(&self, id: usize) -> StaticName;
22
23 fn node_inputs(&self, id: usize) -> &[OutletId];
25
26 fn node_output_count(&self, id: usize) -> usize;
28
29 fn nodes_len(&self) -> usize;
31
32 fn node_display(&self, id: usize) -> String;
34
35 fn node_debug(&self, id: usize) -> String;
37
38 fn eval_order(&self) -> TractResult<Vec<usize>>;
40
41 fn eval_order_opt_ram(&self) -> TractResult<Vec<usize>>;
43
44 fn input_outlets(&self) -> &[OutletId];
46
47 fn set_input_names(&mut self, names: &[&str]) -> TractResult<()>;
48 fn set_output_names(&mut self, names: &[&str]) -> TractResult<()>;
49
50 fn output_outlets(&self) -> &[OutletId];
52
53 fn outlet_typedfact(&self, outlet: OutletId) -> TractResult<TypedFact>;
55
56 fn outlet_fact_format(&self, outlet: OutletId) -> String;
58
59 fn outlet_label(&self, id: OutletId) -> Option<&str>;
61
62 fn outlet_successors(&self, outlet: OutletId) -> &[InletId];
64
65 fn nested_models(&self, id: usize) -> Vec<(String, &dyn Model)> {
67 if let Some(submodel) =
68 self.node_op(id).downcast_ref::<tract_core::ops::submodel::SubmodelOp>()
69 {
70 return vec![("submodel".into(), submodel.model())];
71 }
72 if let Some(lir) = self.node_op(id).downcast_ref::<tract_core::ops::scan::OptScan>() {
73 return vec![("loop".into(), lir.plan.model())];
74 }
75 if let Some(mir) = self.node_op(id).downcast_ref::<tract_core::ops::scan::Scan>() {
76 return vec![("loop".into(), &mir.body)];
77 }
78 if let Some(mir) = self.node_op(id).downcast_ref::<tract_core::ops::logic::IfThenElse>() {
79 return vec![("then".into(), &mir.then_body), ("else".into(), &mir.else_body)];
80 }
81 #[cfg(feature = "hir")]
82 if let Some(hir) = self.node_op(id).downcast_ref::<tract_hir::ops::scan::InferenceScan>() {
83 return vec![("loop".into(), &hir.body)];
84 }
85 #[cfg(feature = "onnx")]
86 if let Some(hir) = self.node_op(id).downcast_ref::<tract_onnx::ops::logic::If>() {
87 return vec![("then".into(), &hir.then_body), ("else".into(), &hir.else_body)];
88 }
89 vec![]
90 }
91
92 fn nested_models_iters(&self, id: usize, input: &[&TypedFact]) -> Option<TDim> {
94 if let Some(submodel) =
95 self.node_op(id).downcast_ref::<tract_core::ops::submodel::SubmodelOp>()
96 {
97 submodel.iteration_count(input)
98 } else if let Some(lir) = self.node_op(id).downcast_ref::<tract_core::ops::scan::OptScan>()
99 {
100 lir.iteration_count(input)
101 } else if let Some(mir) = self.node_op(id).downcast_ref::<tract_core::ops::scan::Scan>() {
102 mir.iteration_count(input)
103 } else {
104 None
105 }
106 }
107
108 fn auto_outputs(&mut self) -> TractResult<()>;
109
110 fn properties(&self) -> &HashMap<String, Arc<Tensor>>;
111
112 fn symbols(&self) -> &SymbolScope;
113
114 fn get_or_intern_symbol(&self, name: &str) -> Symbol;
115
116 fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()>;
117}
118
119downcast_rs::impl_downcast!(Model);
120dyn_clone::clone_trait_object!(Model);
121
122impl<F, O> Model for Graph<F, O>
123where
124 F: Fact + Hash + Clone + 'static,
125 O: std::fmt::Debug
126 + std::fmt::Display
127 + AsRef<dyn Op>
128 + AsMut<dyn Op>
129 + Clone
130 + 'static
131 + Send
132 + Sync,
133 Graph<F, O>: Send + Sync + 'static,
134{
135 fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
136 self.nodes
137 .iter()
138 .find(|n| n.name == name)
139 .map(|n| n.id)
140 .with_context(|| format!("No node found for name: \"{name}\""))
141 }
142
143 fn node_name(&self, id: usize) -> &str {
144 &self.nodes[id].name
145 }
146
147 fn node_op_name(&self, id: usize) -> StaticName {
148 self.node(id).op().name()
149 }
150
151 fn node_const(&self, id: usize) -> bool {
152 self.node_op_name(id) == "Const"
153 }
154
155 fn node_inputs(&self, id: usize) -> &[OutletId] {
156 &self.nodes[id].inputs
157 }
158
159 fn node_output_count(&self, id: usize) -> usize {
160 self.nodes[id].outputs.len()
161 }
162
163 fn nodes_len(&self) -> usize {
164 self.nodes.len()
165 }
166
167 fn node_display(&self, id: usize) -> String {
168 format!("{}", self.nodes[id])
169 }
170
171 fn node_debug(&self, id: usize) -> String {
172 format!("{:?}", self.nodes[id])
173 }
174
175 fn eval_order(&self) -> TractResult<Vec<usize>> {
176 tract_core::model::order::eval_order(self)
177 }
178
179 fn eval_order_opt_ram(&self) -> TractResult<Vec<usize>> {
180 tract_core::model::order::eval_order_opt_ram(self)
181 }
182
183 fn input_outlets(&self) -> &[OutletId] {
184 &self.inputs
185 }
186
187 fn set_input_names(&mut self, names: &[&str]) -> TractResult<()> {
188 self.set_input_names(names.iter())
189 }
190
191 fn set_output_names(&mut self, names: &[&str]) -> TractResult<()> {
192 self.set_output_names(names)
193 }
194
195 fn output_outlets(&self) -> &[OutletId] {
196 &self.outputs
197 }
198
199 fn node_op(&self, id: usize) -> &dyn Op {
200 self.nodes[id].op.as_ref()
201 }
202
203 fn outlet_typedfact(&self, outlet: OutletId) -> TractResult<TypedFact> {
204 Ok(self.outlet_fact(outlet)?.to_typed_fact()?.into_owned())
205 }
206
207 fn outlet_fact_format(&self, outlet: OutletId) -> String {
208 format!("{:?}", self.outlet_fact(outlet).unwrap())
209 }
210
211 fn outlet_label(&self, id: OutletId) -> Option<&str> {
212 self.outlet_label(id)
213 }
214
215 fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
216 &self.nodes[outlet.node].outputs[outlet.slot].successors
217 }
218
219 fn auto_outputs(&mut self) -> TractResult<()> {
220 self.auto_outputs()
221 }
222
223 fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
224 &self.properties
225 }
226
227 fn symbols(&self) -> &SymbolScope {
228 &self.symbols
229 }
230 fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
231 self.rename_node(id, name)
232 }
233
234 fn get_or_intern_symbol(&self, name: &str) -> Symbol {
235 self.symbols.sym(name)
236 }
237}