Struct tract_core::model::Graph
source · pub struct Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,{
pub nodes: Vec<Node<F, O>>,
pub inputs: Vec<OutletId>,
pub outputs: Vec<OutletId>,
pub outlet_labels: HashMap<OutletId, String>,
pub properties: HashMap<String, Arc<Tensor>>,
pub symbol_table: SymbolTable,
}Expand description
Main model class.
Parameterized by a Fact implementation.
Fields§
§nodes: Vec<Node<F, O>>all nodes in the model
inputs: Vec<OutletId>model inputs
outputs: Vec<OutletId>model outputs
outlet_labels: HashMap<OutletId, String>outlet labels
properties: HashMap<String, Arc<Tensor>>model properties
symbol_table: SymbolTablesymbol table
Implementations§
source§impl<F, O> Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
Graph<F, O>: SpecialOps<F, O>,
impl<F, O> Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
Graph<F, O>: SpecialOps<F, O>,
sourcepub fn add_source(
&mut self,
name: impl Into<String>,
fact: F
) -> TractResult<OutletId>
pub fn add_source(
&mut self,
name: impl Into<String>,
fact: F
) -> TractResult<OutletId>
Examples found in repository?
More examples
779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805
// Evaluates the op by building a throw-away TypedModel around it and
// running that model on the concrete inputs.
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
    let mut model = TypedModel::default();
    // One source node per input tensor, typed from the concrete value.
    let mut wires: TVec<OutletId> = inputs
        .iter()
        .enumerate()
        .map(|(ix, v)| {
            model.add_source(format!("source.{}", ix), v.datum_type().fact(v.shape()))
        })
        .collect::<TractResult<_>>()?;
    // May produce a rewritten op (and adjust `wires`) for the u8-kernel-as-i8 trick.
    let new_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
    // NOTE(review): the wiring helpers below are called inside `unsafe`;
    // their safety preconditions are not visible from this extract — confirm
    // against their declarations.
    let wire = unsafe {
        if self.q_params.is_some() {
            // Quantized path: prefer the rewritten op when one was produced.
            let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
            op_ref.wire_as_quant_im2col(
                &mut model,
                "im2col-adhoc",
                inputs[0].datum_type(),
                &wires,
            )?
        } else {
            self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
        }
    };
    model.set_output_outlets(&[wire])?;
    model.into_runnable()?.run(inputs)
}83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
// Translate one node from the source graph into the target graph,
// converting facts (TI1 -> TI2) and ops (O1 -> O2) along the way.
fn translate_node(
    &self,
    source: &Graph<TI1, O1>,
    node: &Node<TI1, O1>,
    target: &mut Graph<TI2, O2>,
    mapping: &HashMap<OutletId, OutletId>,
) -> TractResult<TVec<OutletId>> {
    // A node counts as an input when every one of its outlets is a graph input.
    let node_is_input =
        (0..node.outputs.len()).all(|o| source.inputs.contains(&(node.id, o).into()));
    if node_is_input {
        // Inputs become fresh source nodes; multi-output inputs get "-<i>" suffixed names.
        (0..node.outputs.len())
            .map(|i| {
                target.add_source(
                    if node.outputs.len() > 1 {
                        format!("{}-{}", node.name, i)
                    } else {
                        node.name.to_string()
                    },
                    TI2::try_from(&node.outputs[i].fact)?,
                )
            })
            .collect()
    } else {
        // Convert the op and all output facts, then rewire inputs through `mapping`.
        let new_op = O2::try_from(&node.op)?;
        let facts = node
            .outputs
            .iter()
            .map(|of| Ok(TI2::try_from(&of.fact)?))
            .collect::<TractResult<TVec<_>>>()?;
        let new_id = target.add_node(node.name.clone(), new_op, facts)?;
        for (ix, o) in node.inputs.iter().enumerate() {
            target.add_edge(mapping[o], InletId::new(new_id, ix))?
        }
        Ok(node.outputs.iter().enumerate().map(|(ix, _)| OutletId::new(new_id, ix)).collect())
    }
}293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
// Declutter rule for Scan: when a scanned input of the loop body feeds a
// single-input/single-output op through which the scan axis survives, pull
// that op outside the loop so it runs once over the whole sequence instead
// of once per chunk.
fn declutter_pull_batcheable_input(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (model_input, input) in self.input_mapping.iter().enumerate() {
        if let Some(info) = input.as_scan() {
            let scan_source = self.body.input_outlets()?[model_input];
            let scan_source_node = self.body.node(scan_source.node);
            for successor in &scan_source_node.outputs[0].successors {
                let successor_node = self.body.node(successor.node);
                // Only single-input/single-output successors can be hoisted.
                if successor_node.inputs.len() != 1 || successor_node.outputs.len() != 1 {
                    continue;
                }
                let (input_facts, output_facts) = self.body.node_facts(successor_node.id)?;
                let invariants = successor_node.op.invariants(&input_facts, &output_facts)?;
                // The scan axis must track through the op unchanged-in-kind.
                if let Some(axis_after) = invariants.unary_track_axis_down(info.axis, false) {
                    // Outer patch: replay the op on the full (unscanned) input
                    // and feed its result to the scan as an extra input.
                    let mut outside_patch = TypedModelPatch::new(format!(
                        "Outer patch for input extraction of {}",
                        successor_node
                    ));
                    let mut patch_inputs = node
                        .inputs
                        .iter()
                        .map(|&i| outside_patch.tap_model(model, i))
                        .collect::<TractResult<TVec<_>>>()?;
                    let input = patch_inputs[info.slot];
                    let new_input_wire = outside_patch.wire_node(
                        format!("{}.extracted.{}", node.name, successor_node.name),
                        successor_node.op.clone(),
                        &[input],
                    )?[0];
                    patch_inputs.push(new_input_wire);
                    // Inner fact is the outer fact restricted to one chunk on
                    // the (tracked) scan axis.
                    let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                    let mut new_input_inner_fact = new_input_outer_fact.clone();
                    new_input_inner_fact.shape.set(axis_after, info.chunk.abs().to_dim());
                    // Inner patch: add a source for the hoisted value and shunt
                    // the hoisted op's output inside the body onto it.
                    let mut new_body = self.body.clone();
                    let new_source_wire = new_body.add_source(
                        format!("{}.extracted.{}", node.name, successor_node.name),
                        new_input_inner_fact,
                    )?;
                    let mut inner_patch = TypedModelPatch::new(format!(
                        "Inner body patch for extraction of {}",
                        successor_node
                    ));
                    let new_source_wire_in_patch =
                        inner_patch.tap_model(&new_body, new_source_wire)?;
                    inner_patch
                        .shunt_outside(
                            &new_body,
                            OutletId::new(successor.node, 0),
                            new_source_wire_in_patch,
                        )
                        .with_context(|| "patching inner model")?;
                    inner_patch.apply(&mut new_body)?;
                    // Register the hoisted wire as an extra scanned input.
                    let mut input_mapping = self.input_mapping.clone();
                    input_mapping.push(InputMapping::Scan(ScanInfo {
                        axis: axis_after,
                        chunk: info.chunk,
                        slot: node.inputs.len(),
                    }));
                    let new_op = Self {
                        input_mapping,
                        output_mapping: self.output_mapping.clone(),
                        decluttered: false,
                        body: new_body,
                        skip: self.skip,
                        seq_length_input_slot: self.seq_length_input_slot,
                    };
                    let output_wires =
                        outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                    for w in output_wires {
                        outside_patch
                            .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                            .with_context(|| "patching outer model")?;
                    }
                    return Ok(Some(outside_patch));
                }
            }
        }
    }
    Ok(None)
}146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
// Declutter rule: collapse a dequantize -> (element-wise ops) -> quantize
// chain. First tries Op::quantize() on every op in the chain; if that fails
// and the incoming type is 8-bit, replaces the whole chain with a 256-entry
// lookup table computed by running the chain over all byte values.
fn declutter(
    &self,
    model: &TypedModel,
    dequant: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let mut current = dequant;
    let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
    // Walk down the single-successor chain looking for a quantize op.
    while let Some(quant) = model.single_succ(current.id)? {
        // Extract (scale, zero_point, dtype) if `quant` is a QuantizeLinear{U8,I8}.
        let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
            if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
                Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
            } else {
                op.0.downcast_ref::<QuantizeLinearI8>()
                    .map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
            }
        } else {
            None
        };
        if let Some((scale, zero_point, dt)) = q_params {
            // first, try Op::quantize() on all ops in the chain
            let mut patch = TypedModelPatch::default();
            let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
            let mut next = model.single_succ(dequant.id)?.unwrap();
            loop {
                if let Some(op) = next
                    .op
                    .quantize(model, dequant, dt, scale, zero_point)
                    .with_context(|| format!("Quantizing {}", next))?
                {
                    wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
                } else {
                    // One op in the chain cannot be quantized: abandon this
                    // strategy and fall through to the lookup-table approach.
                    break;
                }
                if next.id == current.id {
                    patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
                    return Ok(Some(patch));
                } else {
                    next = model.single_succ(next.id)?.unwrap();
                }
            }
            // or else make a lookup table
            if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
                // Ad-hoc model running the whole chain over all 256 byte
                // values to tabulate its effect.
                let mut adhoc_model = TypedModel::default();
                let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
                let mut next = model.single_succ(dequant.id)?.unwrap();
                let mut name = None;
                // plug in dequant
                wire = adhoc_model.wire_node(
                    &*dequant.name,
                    dequant.op.clone(),
                    [wire].as_ref(),
                )?[0];
                while next.id != quant.id {
                    // Remember the first intermediate op's name for the LUT node.
                    name.get_or_insert(&*next.name);
                    wire =
                        adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
                        [0];
                    next = model.single_succ(next.id)?.unwrap();
                }
                // plug in quant
                wire =
                    adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
                adhoc_model.set_output_outlets(&[wire])?;
                let input = (0u8..=255).collect::<Vec<u8>>();
                let input = match dt {
                    DatumType::I8 => unsafe {
                        // SAFETY: u8 and i8 have identical size and alignment;
                        // this only reinterprets the bytes.
                        tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
                    },
                    DatumType::U8 => tensor1(&input),
                    _ => unreachable!(),
                };
                let output =
                    SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
                let table: &[u8] = match dt {
                    // SAFETY: same i8/u8 byte reinterpretation as above.
                    DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
                    DatumType::U8 => output.as_slice::<u8>()?,
                    _ => unreachable!(),
                };
                // Replace the whole chain by a single 8-bit lookup-table op.
                let op = lookup_table((tract_linalg::ops().lut_u8)(table));
                let mut patch = TypedModelPatch::default();
                let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
                wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
                patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
                return Ok(Some(patch));
            }
        }
        // Keep scanning down only through element-wise ops.
        let (input_facts, output_facts) = model.node_facts(quant.id)?;
        let invariants = quant
            .op
            .invariants(&input_facts, &output_facts)
            .with_context(|| format!("Querying invariants for {}", quant))?;
        if invariants.element_wise() {
            current = quant;
        } else {
            break;
        }
    }
    Ok(None)
}source§impl<F, O> Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
impl<F, O> Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
sourcepub fn add_node(
&mut self,
name: impl Into<String>,
op: impl Into<O>,
output_facts: TVec<F>
) -> TractResult<usize>
pub fn add_node(
&mut self,
name: impl Into<String>,
op: impl Into<O>,
output_facts: TVec<F>
) -> TractResult<usize>
Examples found in repository?
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571
/// Add a source (input) node carrying `fact`, register it as a model
/// input, and return its outlet.
pub fn add_source(&mut self, name: impl Into<String>, fact: F) -> TractResult<OutletId> {
    let source_op = self.create_source(fact.clone());
    let node_id = self.add_node(name, source_op, tvec!(fact))?;
    let outlet = OutletId::new(node_id, 0);
    self.inputs.push(outlet);
    Ok(outlet)
}
}
impl<F, O> Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    /// Add a disconnected node to the graph and return its id.
    ///
    /// `output_facts` fixes the number of outlets and their type
    /// information; inputs are connected afterwards with [`Graph::add_edge`].
    pub fn add_node(
        &mut self,
        name: impl Into<String>,
        op: impl Into<O>,
        output_facts: TVec<F>,
    ) -> TractResult<usize> {
        let op = op.into();
        let name = name.into();
        // A node's id is its index in the nodes table.
        let id = self.nodes.len();
        let outputs =
            output_facts.into_iter().map(|fact| Outlet { fact, successors: tvec!() }).collect();
        let node = Node { id, name, op, inputs: vec![], outputs };
        self.nodes.push(node);
        Ok(id)
    }
    /// Connect a node outlet to a node inlet.
    pub fn add_edge(&mut self, outlet: OutletId, inlet: InletId) -> TractResult<()> {
        // If the inlet slot was already wired, unregister the old edge from
        // its previous producer's successors list.
        if let Some(previous) = self.nodes[inlet.node].inputs.get(inlet.slot).cloned() {
            self.nodes[previous.node].outputs[previous.slot]
                .successors
                .retain(|&mut succ| succ != inlet);
        }
        {
            let prec = &mut self.nodes[outlet.node];
            prec.outputs[outlet.slot].successors.push(inlet);
        }
        let succ = &mut self.nodes[inlet.node];
        #[allow(clippy::comparison_chain)]
        if inlet.slot == succ.inputs.len() {
            succ.inputs.push(outlet);
        } else if inlet.slot < succ.inputs.len() {
            succ.inputs[inlet.slot] = outlet;
        } else {
            bail!("Edges must be added in order and consecutive. Trying to connect input {:?} of node {:?} ", inlet.slot, succ)
        }
        Ok(())
    }
    // Inputs
    /// Get model inputs.
    pub fn input_outlets(&self) -> TractResult<&[OutletId]> {
        Ok(&self.inputs)
    }
    /// Change model inputs.
    pub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()> {
        self.inputs = inputs.to_vec();
        Ok(())
    }
    /// Change model inputs and return `self`.
    pub fn with_input_outlets(mut self, inputs: &[OutletId]) -> TractResult<Self> {
        self.set_input_outlets(inputs)?;
        Ok(self)
    }
    /// Set model inputs by the node name. Every outlet of each named node
    /// becomes a model input.
    pub fn set_input_names(
        &mut self,
        inputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut ids = vec![];
        for i in inputs.into_iter() {
            let node = self.node_by_name(&i)?;
            for o in 0..node.outputs.len() {
                ids.push(OutletId::new(node.id, o))
            }
        }
        self.inputs = ids;
        Ok(())
    }
    /// Set model inputs by the node name and return `self`.
    pub fn with_input_names(
        mut self,
        inputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_input_names(inputs)?;
        Ok(self)
    }
    /// Get the `ix`-th input tensor type information.
    pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
        let input = self.input_outlets()?[ix];
        self.outlet_fact(input)
    }
    /// Get the `ix`-th input tensor type information, mutably.
    pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        let input = self.input_outlets()?[ix];
        self.outlet_fact_mut(input)
    }
    /// Set the `ix`-th input tensor type information.
    pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
        let outlet = self.inputs[input];
        self.set_outlet_fact(outlet, fact)
    }
    /// Set the `ix`-th input tensor type information and return `self`.
    pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
        self.set_input_fact(input, fact)?;
        Ok(self)
    }
    // Outputs
    /// Get model outputs.
    pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
        Ok(&self.outputs)
    }
    /// Guess outputs from the topology: node or nodes with no successors.
    pub fn auto_outputs(&mut self) -> TractResult<()> {
        let outputs = self
            .nodes
            .iter()
            .flat_map(|n| {
                let id = n.id;
                n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
                    (OutletId::new(id, ix), output_fact.successors.len())
                })
            })
            .filter(|(_f, succs)| *succs == 0)
            .map(|(f, _)| f)
            .collect();
        self.outputs = outputs;
        Ok(())
    }
    /// Change model outputs.
    pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
        self.outputs = outputs.to_vec();
        Ok(())
    }
    /// Change model outputs and return `self`.
    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
        self.set_output_outlets(outputs)?;
        Ok(self)
    }
    /// Set model outputs by node names. Names resolve first against outlet
    /// labels and generated `name:slot` forms, then against plain node names.
    pub fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        // Candidate map: explicit outlet labels plus generated "name:slot".
        let mut labels: HashMap<Cow<str>, OutletId> =
            self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
        for n in self.nodes() {
            for ix in 0..n.outputs.len() {
                labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
            }
        }
        let ids: Vec<OutletId> = outputs
            .into_iter()
            .map(|s| {
                let s = s.as_ref();
                labels
                    .get(s)
                    .cloned()
                    .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
                    .ok_or_else(|| format_err!("Node {} not found", s))
            })
            .collect::<TractResult<_>>()?;
        self.outputs = ids;
        Ok(())
    }
    /// Set model outputs by node names and return `self`.
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs)?;
        Ok(self)
    }
    /// Get the `ix`-th output tensor type information.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        let output = self.output_outlets()?[ix];
        self.outlet_fact(output)
    }
    /// Get the `ix`-th output tensor type information, mutably.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        let output = self.output_outlets()?[ix];
        self.outlet_fact_mut(output)
    }
    /// Set the `ix`-th output tensor type information.
    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
        let outlet = self.outputs[output];
        self.set_outlet_fact(outlet, fact)
    }
    /// Set the `ix`-th output tensor type information and return `self`.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact)?;
        Ok(self)
    }
    // nodes and their facts
    /// Iterate over all node names.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.iter().map(|s| &*s.name)
    }
    /// Find a node id by its name.
    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        self.nodes
            .iter()
            .find(|n| n.name == name)
            .map(|n| n.id)
            .with_context(|| format!("No node found for name: \"{}\"", name))
    }
    /// Find a node by its name.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        let id: usize = self.node_id_by_name(name.as_ref())?;
        Ok(&self.nodes[id])
    }
    /// Borrow mutably a node by its name.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let id: usize = self.node_id_by_name(name.as_ref())?;
        Ok(&mut self.nodes[id])
    }
    /// Rename the node with the given id.
    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        self.node_mut(id).name = name.to_string();
        Ok(())
    }
    /// Find a node by its id.
    pub fn node(&self, id: usize) -> &Node<F, O> {
        &self.nodes[id]
    }
    /// Find a node by its id, mutably.
    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
        &mut self.nodes[id]
    }
    /// Access the nodes table.
    pub fn nodes(&self) -> &[Node<F, O>] {
        &self.nodes
    }
    /// Access the nodes table, mutably.
    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
        &mut self.nodes
    }
    /// Get input and output tensor information for a node.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
    }
    /// Get input tensor information for a node.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
    }
    /// Get output tensor information for a node.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
    }
    // outlets
    /// Get tensor information for a single outlet.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let outlets = &self.nodes[outlet.node].outputs;
        outlets
            .get(outlet.slot)
            .map(|o| &o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }
    /// Get tensor information for a single outlet, mutably.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }
    /// Get multiple mutable tensor information for outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outlets` contains duplicate entries.
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        // Distinctness is required for the aliasing below to be sound.
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        // SAFETY: all outlets are pairwise distinct (asserted above), so the
        // mutable references handed out never alias each other.
        unsafe {
            outlets
                .iter()
                .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
                .collect()
        }
    }
    /// Set tensor information for a single outlet.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        if outlets.len() <= outlet.slot {
            bail!("Invalid outlet reference: {:?}", outlet)
        }
        outlets[outlet.slot].fact = fact;
        Ok(())
    }
    /// Set tensor information for a single outlet and return `self`.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact)?;
        Ok(self)
    }
    // outlet labels
    /// Get label for an outlet.
    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
        self.outlet_labels.get(&outlet).map(|s| &**s)
    }
    /// Set label for an outlet.
    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
        self.outlet_labels.insert(outlet, label);
        Ok(())
    }
    /// Set label for an outlet and return `self`.
    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
        self.set_outlet_label(outlet, label)?;
        Ok(self)
    }
    /// Find outlet by label.
    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
        self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
    }
    // misc
    /// Computes an evaluation order for the graph inputs and outputs
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        eval_order(self)
    }
    /// No-op stand-in when paranoid assertions are disabled.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }
    /// Performs a sanity check on network connections: every edge must be
    /// registered symmetrically on both its producer and its consumer.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        for node_id in self.eval_order()? {
            let node = &self.nodes[node_id];
            for (ix, input) in node.inputs.iter().enumerate() {
                let prec = &self.nodes[input.node];
                if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
                    bail!(
                        "Mismatched incoming edge, node:{} input:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        prec
                    )
                }
            }
            for (ix, output) in node.outputs.iter().enumerate() {
                for succ in &output.successors {
                    if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
                        bail!(
                            "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                            node.id,
                            ix,
                            succ
                        )
                    }
                }
            }
        }
        Ok(())
    }
    /// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        crate::plan::SimplePlan::new(self)
    }
    /// Returns the single precursor of a node: the producer of its one
    /// input, provided that producer feeds nothing else. `None` otherwise.
    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        if node.inputs.len() != 1 {
            return Ok(None);
        }
        let prec = &self.nodes()[node.inputs[0].node];
        if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        Ok(Some(prec))
    }
    /// Walks `count` single-precursor steps up from `id`.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut node = self.node(id);
        for _ in 0..count {
            if let Some(next) = self.single_prec(node.id)? {
                node = next
            } else {
                return Ok(None);
            }
        }
        Ok(Some(node))
    }
    /// Walks `count` single-successor steps down from `id`.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut node = self.node(id);
        for _ in 0..count {
            if let Some(next) = self.single_succ(node.id)? {
                node = next
            } else {
                return Ok(None);
            }
        }
        Ok(Some(node))
    }
    /// Returns the single successor of a node: the sole consumer of its
    /// sole consumed outlet, provided that consumer has no other input.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        let succ = node.outputs[0].successors[0];
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }
    /// List the inlets connected to an outlet.
    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        &self.nodes[outlet.node].outputs[outlet.slot].successors
    }
}
impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
O: fmt::Debug
+ fmt::Display
+ From<crate::ops::konst::Const>
+ AsRef<dyn Op>
+ AsMut<dyn Op>
+ Clone
+ Hash
+ 'static,
{
/// Add a constant node holding tensor `v`, returning its outlet.
pub fn add_const(
    &mut self,
    name: impl Into<String>,
    v: impl IntoArcTensor,
) -> TractResult<OutletId> {
    let v = v.into_arc_tensor();
    // The fact is derived from the tensor itself (F: From<Arc<Tensor>>).
    let fact = F::from(v.clone());
    let name = name.into();
    self.add_node(name, crate::ops::konst::Const::new(v), tvec!(fact)).map(|id| id.into())
}More examples
164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
/// Convenience method creating a patch that fuses `node` with its single
/// successor: `new_op` takes the node's inputs and replaces the
/// successor's outputs.
pub fn fuse_with_next<IO: Into<O>>(
    patched_model: &Graph<F, O>,
    node: &Node<F, O>,
    new_op: IO,
) -> TractResult<ModelPatch<F, O>> {
    let mut patch = ModelPatch::default();
    let succ = if let Some(succ) = patched_model.single_succ(node.id)? {
        succ
    } else {
        bail!("Non single successor fuse attempt")
    };
    let new_op = new_op.into();
    // The fused node inherits the successor's output fact.
    let by = patch.add_node(&*node.name, new_op, tvec!(succ.outputs[0].fact.clone()))?;
    for (ix, i) in node.inputs.iter().enumerate() {
        let o = patch.tap_model(patched_model, *i)?;
        patch.add_edge(o, InletId::new(by, ix))?;
    }
    for ix in 0..node.outputs.len() {
        patch.shunt_outside(
            patched_model,
            OutletId::new(succ.id, ix),
            OutletId::new(by, ix),
        )?;
    }
    Ok(patch)
}
/// Convenience method creating a patch that shunts the given node.
pub fn shunt_one_op(
    patched_model: &Graph<F, O>,
    node: &Node<F, O>,
) -> TractResult<ModelPatch<F, O>> {
    Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into()))
}
/// Build a patch that taps `from`, runs `wiring` on the taps, and shunts
/// the resulting wires onto `to`.
#[allow(clippy::type_complexity)]
pub fn rewire(
    patched_model: &Graph<F, O>,
    from: &[OutletId],
    to: &[OutletId],
    wiring: &dyn Fn(&mut Self, &[OutletId]) -> TractResult<TVec<OutletId>>,
) -> TractResult<ModelPatch<F, O>> {
    let mut patch = ModelPatch::default();
    let taps = from
        .iter()
        .map(|f| patch.tap_model(patched_model, *f))
        .collect::<TractResult<TVec<_>>>()?;
    let news = wiring(&mut patch, &taps)?;
    // `wiring` must produce exactly one wire per shunted outlet.
    if news.len() != to.len() {
        bail!(
            "Wrong number of outputs for rewiring, expected {}, function returned {}",
            to.len(),
            news.len()
        );
    }
    for (new, &old) in izip!(news, to) {
        patch.shunt_outside(patched_model, old, new)?;
    }
    Ok(patch)
}
/// Convenience method creating a patch that replace a single unary operation.
pub fn single_unary_op<IO: Into<O>>(
    patched_model: &Graph<F, O>,
    node: &Node<F, O>,
    new_op: IO,
) -> TractResult<ModelPatch<F, O>> {
    Self::replace_single_op(patched_model, node, &[node.inputs[0]], new_op)
}
/// Convenience method creating a patch that insert an unary op on an outlet.
pub fn intercept<IO: Into<O>>(
    patched_model: &Graph<F, O>,
    outlet: OutletId,
    name: impl Into<String>,
    new_op: IO,
    fact: F,
) -> TractResult<ModelPatch<F, O>> {
    let mut patch = ModelPatch::default();
    let tap = patch.tap_model(patched_model, outlet)?;
    let new_id = patch.add_node(name, new_op, tvec!(fact))?;
    patch.add_edge(tap, InletId::new(new_id, 0))?;
    patch.shunt_outside(patched_model, outlet, OutletId::new(new_id, 0))?;
    Ok(patch)
}
/// Apply all changes in the patch to the target model.
pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
    let prior_target_inputs = target.input_outlets()?.len();
    let prior_target_outputs = target.output_outlets()?.len();
    let ModelPatch {
        model: patch,
        incoming: mut mapping,
        shunt_outlet_by,
        obliterate,
        inputs: replaced_inputs,
        ..
    } = self;
    let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
    let mut model_input_outlets = target.input_outlets()?.to_vec();
    // Phase 1: import every patch node into the target (taps excepted).
    for node in patch.nodes {
        if <Graph<F, O>>::is_source(&node.op)
            && mapping.contains_key(&OutletId::new(node.id, 0))
        {
            // this is a tap
            continue;
        }
        let Node { id: patch_node_id, name, inputs, op, outputs } = node;
        let n_outputs = outputs.len();
        // Deduplication scan: map onto an existing identical node if any.
        for dup in 0..target.nodes.len() {
            if target.node(dup).op().same_as(op.as_ref())
                && inputs.len() == target.node(dup).inputs.len()
                && inputs
                    .iter()
                    .zip(target.node(dup).inputs.iter())
                    .all(|(patch_input, d)| mapping[patch_input] == *d)
            {
                for ix in 0..n_outputs {
                    mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
                }
                // NOTE(review): this `continue` only advances the inner `dup`
                // scan; the node is still added below and the mapping entries
                // just inserted get overwritten. A labeled `continue` on the
                // outer node loop looks intended — confirm against upstream.
                continue;
            }
        }
        let facts = outputs.into_iter().map(|of| of.fact).collect();
        let added_node_id = target.add_node(name, op, facts)?;
        for ix in 0..n_outputs {
            mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
        }
        all_inputs.insert(added_node_id, inputs);
        if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
            // this is actually an input replacement
            model_input_outlets.iter_mut().for_each(|oo| {
                if oo.node == replaced_inputs[&patch_node_id] {
                    oo.node = added_node_id;
                }
            });
        }
    }
    debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
    debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
    // Phase 2: rewire shunted outlets to their replacements, moving model
    // outputs and outlet labels along.
    for (outlet, by) in shunt_outlet_by {
        let replace_by = mapping[&by];
        let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
        for succ in succs {
            target.add_edge(replace_by, succ)?;
        }
        for o in target.outputs.iter_mut() {
            if *o == outlet {
                *o = replace_by;
            }
        }
        if let Some(label) = target.outlet_labels.remove(&outlet) {
            target.set_outlet_label(replace_by, label)?;
        }
    }
    if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
        bail!("Duplicate usage of node as output");
    }
    debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
    debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
    // Phase 3: wire the imported nodes' inputs through the outlet mapping.
    for (node, inputs) in all_inputs {
        for (ix, input) in inputs.into_iter().enumerate() {
            target.add_edge(mapping[&input], InletId::new(node, ix))?;
        }
    }
    debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
    debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
    // Phase 4: replace obliterated nodes by dummy ops.
    for node in obliterate {
        target.node_mut(node).op = target.create_dummy();
    }
    debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
    debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
    target.set_input_outlets(&model_input_outlets)?;
    Ok(())
}83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
// Translate one node from the source graph into the target graph: input
// nodes become fresh sources, other nodes get op and facts converted.
fn translate_node(
    &self,
    source: &Graph<TI1, O1>,
    node: &Node<TI1, O1>,
    target: &mut Graph<TI2, O2>,
    mapping: &HashMap<OutletId, OutletId>,
) -> TractResult<TVec<OutletId>> {
    // A node counts as an input when all its outlets are graph inputs.
    let node_is_input =
        (0..node.outputs.len()).all(|o| source.inputs.contains(&(node.id, o).into()));
    if node_is_input {
        // Multi-output inputs get "-<i>" suffixed source names.
        (0..node.outputs.len())
            .map(|i| {
                target.add_source(
                    if node.outputs.len() > 1 {
                        format!("{}-{}", node.name, i)
                    } else {
                        node.name.to_string()
                    },
                    TI2::try_from(&node.outputs[i].fact)?,
                )
            })
            .collect()
    } else {
        // Convert op and output facts, then rewire inputs through `mapping`.
        let new_op = O2::try_from(&node.op)?;
        let facts = node
            .outputs
            .iter()
            .map(|of| Ok(TI2::try_from(&of.fact)?))
            .collect::<TractResult<TVec<_>>>()?;
        let new_id = target.add_node(node.name.clone(), new_op, facts)?;
        for (ix, o) in node.inputs.iter().enumerate() {
            target.add_edge(mapping[o], InletId::new(new_id, ix))?
        }
        Ok(node.outputs.iter().enumerate().map(|(ix, _)| OutletId::new(new_id, ix)).collect())
    }
}39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
/// Wire `op` as a new node named `name`, fed by `inputs`; returns its outlets.
///
/// Output facts come from `op.output_facts`; when the op is stateless and
/// all inputs carry constants, the op is eagerly evaluated so the outlets
/// carry constant facts instead (constant folding).
fn wire_node(
    &mut self,
    name: impl Into<String>,
    op: impl Into<Box<dyn TypedOp>>,
    inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
    let op = op.into();
    let name = name.into();
    {
        let output_facts = || -> TractResult<TVec<TypedFact>> {
            let input_facts = inputs
                .iter()
                .map(|o| self.outlet_fact(*o))
                .collect::<TractResult<TVec<_>>>()?;
            let facts = op.output_facts(&input_facts).context("in output_facts invocation")?;
            // Constant folding: stateless op over all-constant inputs.
            if input_facts.iter().all(|f| f.konst.is_some()) && op.is_stateless() {
                let tensors = input_facts
                    .iter()
                    .map(|f| f.konst.clone().unwrap().into_tvalue())
                    .collect::<TVec<_>>();
                // A failed eval is not fatal: fall back to the declared facts.
                if let Ok(outputs) = op.eval(tensors) {
                    return Ok(outputs.into_iter().map(|t| TypedFact::from(&*t)).collect());
                }
            }
            Ok(facts)
        };
        let output_facts = output_facts()
            .with_context(|| format!("wiring {} ({:?}), determining output_facts", name, op))?;
        let id = self.add_node(&name, &op, output_facts)?;
        inputs
            .iter()
            .enumerate()
            .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
        TractResult::Ok(
            self.node(id)
                .outputs
                .iter()
                .enumerate()
                .map(|(ix, _)| OutletId::new(id, ix))
                .collect(),
        )
    }
    .with_context(|| format!("Wiring node \"{}\", {:?}", name, op))
}
pub fn add_edge(&mut self, outlet: OutletId, inlet: InletId) -> TractResult<()>
Connect a node outlet to a node inlet.
Examples found in repository?
(repository line numbers elided)
/// Build a patch fusing `node` with its single successor: `new_op` replaces
/// the pair, taking `node`'s inputs and producing the successor's outputs.
///
/// Fails if `node` does not have exactly one successor.
pub fn fuse_with_next<IO: Into<O>>(
    patched_model: &Graph<F, O>,
    node: &Node<F, O>,
    new_op: IO,
) -> TractResult<ModelPatch<F, O>> {
    let mut patch = ModelPatch::default();
    let succ = if let Some(succ) = patched_model.single_succ(node.id)? {
        succ
    } else {
        bail!("Non single successor fuse attempt")
    };
    let new_op = new_op.into();
    // The fused node keeps the first node's name but the successor's fact.
    let by = patch.add_node(&*node.name, new_op, tvec!(succ.outputs[0].fact.clone()))?;
    for (ix, i) in node.inputs.iter().enumerate() {
        let o = patch.tap_model(patched_model, *i)?;
        patch.add_edge(o, InletId::new(by, ix))?;
    }
    // Reroute the successor's outputs to the fused node.
    for ix in 0..node.outputs.len() {
        patch.shunt_outside(
            patched_model,
            OutletId::new(succ.id, ix),
            OutletId::new(by, ix),
        )?;
    }
    Ok(patch)
}
/// Convenience method creating a patch that shunts the given node:
/// its output is rerouted straight to its inputs (identity rewiring).
pub fn shunt_one_op(
    patched_model: &Graph<F, O>,
    node: &Node<F, O>,
) -> TractResult<ModelPatch<F, O>> {
    Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into()))
}
/// Build a patch replacing the subgraph between outlets `from` and `to`:
/// `wiring` receives taps on `from` and builds the replacement; the wires
/// it returns are shunted in place of `to`, one for one.
#[allow(clippy::type_complexity)]
pub fn rewire(
    patched_model: &Graph<F, O>,
    from: &[OutletId],
    to: &[OutletId],
    wiring: &dyn Fn(&mut Self, &[OutletId]) -> TractResult<TVec<OutletId>>,
) -> TractResult<ModelPatch<F, O>> {
    let mut patch = ModelPatch::default();
    let taps = from
        .iter()
        .map(|f| patch.tap_model(patched_model, *f))
        .collect::<TractResult<TVec<_>>>()?;
    let news = wiring(&mut patch, &taps)?;
    // The wiring function must cover every replaced outlet.
    if news.len() != to.len() {
        bail!(
            "Wrong number of outputs for rewiring, expected {}, function returned {}",
            to.len(),
            news.len()
        );
    }
    for (new, &old) in izip!(news, to) {
        patch.shunt_outside(patched_model, old, new)?;
    }
    Ok(patch)
}
/// Convenience method creating a patch that replaces a single unary operation,
/// keeping the node's first input as the new op's only input.
pub fn single_unary_op<IO: Into<O>>(
    patched_model: &Graph<F, O>,
    node: &Node<F, O>,
    new_op: IO,
) -> TractResult<ModelPatch<F, O>> {
    Self::replace_single_op(patched_model, node, &[node.inputs[0]], new_op)
}
/// Convenience method creating a patch that inserts an unary op on an outlet:
/// former consumers of `outlet` are rerouted to the new op's output.
pub fn intercept<IO: Into<O>>(
    patched_model: &Graph<F, O>,
    outlet: OutletId,
    name: impl Into<String>,
    new_op: IO,
    fact: F,
) -> TractResult<ModelPatch<F, O>> {
    let mut patch = ModelPatch::default();
    let tap = patch.tap_model(patched_model, outlet)?;
    let new_id = patch.add_node(name, new_op, tvec!(fact))?;
    patch.add_edge(tap, InletId::new(new_id, 0))?;
    patch.shunt_outside(patched_model, outlet, OutletId::new(new_id, 0))?;
    Ok(patch)
}
/// Apply all changes in the patch to the target model.
pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
let prior_target_inputs = target.input_outlets()?.len();
let prior_target_outputs = target.output_outlets()?.len();
let ModelPatch {
model: patch,
incoming: mut mapping,
shunt_outlet_by,
obliterate,
inputs: replaced_inputs,
..
} = self;
let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
let mut model_input_outlets = target.input_outlets()?.to_vec();
for node in patch.nodes {
if <Graph<F, O>>::is_source(&node.op)
&& mapping.contains_key(&OutletId::new(node.id, 0))
{
// this is a tap
continue;
}
let Node { id: patch_node_id, name, inputs, op, outputs } = node;
let n_outputs = outputs.len();
for dup in 0..target.nodes.len() {
if target.node(dup).op().same_as(op.as_ref())
&& inputs.len() == target.node(dup).inputs.len()
&& inputs
.iter()
.zip(target.node(dup).inputs.iter())
.all(|(patch_input, d)| mapping[patch_input] == *d)
{
for ix in 0..n_outputs {
mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
}
continue;
}
}
let facts = outputs.into_iter().map(|of| of.fact).collect();
let added_node_id = target.add_node(name, op, facts)?;
for ix in 0..n_outputs {
mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
}
all_inputs.insert(added_node_id, inputs);
if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
// this is actually an input replacement
model_input_outlets.iter_mut().for_each(|oo| {
if oo.node == replaced_inputs[&patch_node_id] {
oo.node = added_node_id;
}
});
}
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
for (outlet, by) in shunt_outlet_by {
let replace_by = mapping[&by];
let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
for succ in succs {
target.add_edge(replace_by, succ)?;
}
for o in target.outputs.iter_mut() {
if *o == outlet {
*o = replace_by;
}
}
if let Some(label) = target.outlet_labels.remove(&outlet) {
target.set_outlet_label(replace_by, label)?;
}
}
if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
bail!("Duplicate usage of node as output");
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
for (node, inputs) in all_inputs {
for (ix, input) in inputs.into_iter().enumerate() {
target.add_edge(mapping[&input], InletId::new(node, ix))?;
}
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
for node in obliterate {
target.node_mut(node).op = target.create_dummy();
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
target.set_input_outlets(&model_input_outlets)?;
Ok(())
}More examples
(repository line numbers elided)
fn translate_node(
&self,
source: &Graph<TI1, O1>,
node: &Node<TI1, O1>,
target: &mut Graph<TI2, O2>,
mapping: &HashMap<OutletId, OutletId>,
) -> TractResult<TVec<OutletId>> {
let node_is_input =
(0..node.outputs.len()).all(|o| source.inputs.contains(&(node.id, o).into()));
if node_is_input {
(0..node.outputs.len())
.map(|i| {
target.add_source(
if node.outputs.len() > 1 {
format!("{}-{}", node.name, i)
} else {
node.name.to_string()
},
TI2::try_from(&node.outputs[i].fact)?,
)
})
.collect()
} else {
let new_op = O2::try_from(&node.op)?;
let facts = node
.outputs
.iter()
.map(|of| Ok(TI2::try_from(&of.fact)?))
.collect::<TractResult<TVec<_>>>()?;
let new_id = target.add_node(node.name.clone(), new_op, facts)?;
for (ix, o) in node.inputs.iter().enumerate() {
target.add_edge(mapping[o], InletId::new(new_id, ix))?
}
Ok(node.outputs.iter().enumerate().map(|(ix, _)| OutletId::new(new_id, ix)).collect())
}
}39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
fn wire_node(
&mut self,
name: impl Into<String>,
op: impl Into<Box<dyn TypedOp>>,
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
let op = op.into();
let name = name.into();
{
let output_facts = || -> TractResult<TVec<TypedFact>> {
let input_facts = inputs
.iter()
.map(|o| self.outlet_fact(*o))
.collect::<TractResult<TVec<_>>>()?;
let facts = op.output_facts(&input_facts).context("in output_facts invocation")?;
if input_facts.iter().all(|f| f.konst.is_some()) && op.is_stateless() {
let tensors = input_facts
.iter()
.map(|f| f.konst.clone().unwrap().into_tvalue())
.collect::<TVec<_>>();
if let Ok(outputs) = op.eval(tensors) {
return Ok(outputs.into_iter().map(|t| TypedFact::from(&*t)).collect());
}
}
Ok(facts)
};
let output_facts = output_facts()
.with_context(|| format!("wiring {} ({:?}), determining output_facts", name, op))?;
let id = self.add_node(&name, &op, output_facts)?;
inputs
.iter()
.enumerate()
.try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
TractResult::Ok(
self.node(id)
.outputs
.iter()
.enumerate()
.map(|(ix, _)| OutletId::new(id, ix))
.collect(),
)
}
.with_context(|| format!("Wiring node \"{}\", {:?}", name, op))
}sourcepub fn input_outlets(&self) -> TractResult<&[OutletId]>
pub fn input_outlets(&self) -> TractResult<&[OutletId]>
Get model inputs.
Examples found in repository?
(repository line numbers elided)
/// Get the `ix`-th input tensor type information.
pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
    let input = self.input_outlets()?[ix];
    self.outlet_fact(input)
}
/// Get the `ix`-th input tensor type information, mutably.
pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
    let input = self.input_outlets()?[ix];
    self.outlet_fact_mut(input)
}
/// Set the `ix`-th input tensor type information.
pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
    let outlet = self.inputs[input];
    self.set_outlet_fact(outlet, fact)
}
/// Set the `ix`-th input tensor type information and return `self`.
pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
    self.set_input_fact(input, fact)?;
    Ok(self)
}
// Outputs
/// Get model outputs.
pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
    Ok(&self.outputs)
}
/// Guess outputs from the topology: node or nodes with no successors.
pub fn auto_outputs(&mut self) -> TractResult<()> {
    let outputs = self
        .nodes
        .iter()
        .flat_map(|n| {
            let id = n.id;
            n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
                (OutletId::new(id, ix), output_fact.successors.len())
            })
        })
        // Keep only outlets with no consumers.
        .filter(|(_f, succs)| *succs == 0)
        .map(|(f, _)| f)
        .collect();
    self.outputs = outputs;
    Ok(())
}
/// Change model outputs.
pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
    self.outputs = outputs.to_vec();
    Ok(())
}
/// Change model outputs and return `self`.
pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
    self.set_output_outlets(outputs)?;
    Ok(self)
}
/// Set model outputs by node names.
pub fn set_output_names(
    &mut self,
    outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<()> {
    // Resolve against outlet labels first, then "name:slot" forms,
    // finally bare node names.
    let mut labels: HashMap<Cow<str>, OutletId> =
        self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
    for n in self.nodes() {
        for ix in 0..n.outputs.len() {
            labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
        }
    }
    let ids: Vec<OutletId> = outputs
        .into_iter()
        .map(|s| {
            let s = s.as_ref();
            labels
                .get(s)
                .cloned()
                .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
                .ok_or_else(|| format_err!("Node {} not found", s))
        })
        .collect::<TractResult<_>>()?;
    self.outputs = ids;
    Ok(())
}
/// Set model outputs by node names and return `self`.
pub fn with_output_names(
    mut self,
    outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<Self> {
    self.set_output_names(outputs)?;
    Ok(self)
}
/// Get the `ix`-th output tensor type information.
pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
    let output = self.output_outlets()?[ix];
    self.outlet_fact(output)
}
/// Get the `ix`-th output tensor type information, mutably.
pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
    let output = self.output_outlets()?[ix];
    self.outlet_fact_mut(output)
}
/// Set the `ix`-th output tensor type information.
pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
    let outlet = self.outputs[output];
    self.set_outlet_fact(outlet, fact)
}
/// Set the `ix`-th output tensor type information and return `self`.
pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
    self.set_output_fact(output, fact)?;
    Ok(self)
}
// nodes and their facts
/// Iterate over all node names.
pub fn node_names(&self) -> impl Iterator<Item = &str> {
    self.nodes.iter().map(|s| &*s.name)
}
/// Find a node id by its name.
pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
    self.nodes
        .iter()
        .find(|n| n.name == name)
        .map(|n| n.id)
        .with_context(|| format!("No node found for name: \"{}\"", name))
}
/// Find a node by its name.
pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
    let id: usize = self.node_id_by_name(name.as_ref())?;
    Ok(&self.nodes[id])
}
/// Borrow mutably a node by its name.
pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
    let id: usize = self.node_id_by_name(name.as_ref())?;
    Ok(&mut self.nodes[id])
}
/// Rename the node `id`.
pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
    self.node_mut(id).name = name.to_string();
    Ok(())
}
/// Find a node by its id.
pub fn node(&self, id: usize) -> &Node<F, O> {
    &self.nodes[id]
}
/// Find a node by its id, mutably.
pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
    &mut self.nodes[id]
}
/// Access the nodes table.
pub fn nodes(&self) -> &[Node<F, O>] {
    &self.nodes
}
/// Access the nodes table, mutably.
pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
    &mut self.nodes
}
/// Get input and output tensor information for a node.
pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
    Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
}
/// Get input tensor information for a node.
pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
    self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
}
/// Get output tensor information for a node.
pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
    Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
}
// outlets
/// Get tensor information for a single outlet.
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
    anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
    let outlets = &self.nodes[outlet.node].outputs;
    outlets
        .get(outlet.slot)
        .map(|o| &o.fact)
        .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get tensor information for a single outlet, mutably.
pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
    let outlets = &mut self.nodes[outlet.node].outputs;
    outlets
        .get_mut(outlet.slot)
        .map(|o| &mut o.fact)
        .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get multiple mutable tensor information for outlets.
pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
    // The outlets must be pairwise distinct, otherwise we would hand out
    // aliasing mutable references.
    assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
    // SAFETY: the assert above guarantees distinct outlets, so the mutable
    // references forged below do not alias.
    unsafe {
        outlets
            .iter()
            .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
            .collect()
    }
}
/// Set tensor information for a single outlet.
///
/// Errors out (rather than panicking) when the outlet slot does not exist.
pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
    let outlets = &mut self.nodes[outlet.node].outputs;
    if outlets.len() <= outlet.slot {
        // Fixed typo ("refererence") and aligned the wording with the
        // sibling getters' "Invalid outlet reference" message.
        bail!("Invalid outlet reference: {:?}", outlet)
    }
    outlets[outlet.slot].fact = fact;
    Ok(())
}
/// Set tensor information for a single outlet and return `self`.
pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
    self.set_outlet_fact(outlet, fact)?;
    Ok(self)
}
// outlet labels
/// Get label for an outlet.
pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
    self.outlet_labels.get(&outlet).map(|s| &**s)
}
/// Set label for an outlet.
pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
    self.outlet_labels.insert(outlet, label);
    Ok(())
}
/// Set label for an outlet and return `self`.
pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
    self.set_outlet_label(outlet, label)?;
    Ok(self)
}
/// Find outlet by label.
pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
    self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
}
// misc
/// Computes an evaluation order for the graph inputs and outputs
pub fn eval_order(&self) -> TractResult<Vec<usize>> {
    eval_order(self)
}
// No-op stand-in when paranoid assertions are disabled.
#[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
#[inline]
pub fn check_edges(&self) -> TractResult<()> {
    Ok(())
}
/// Performs a sanity check on network connections.
#[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
#[inline]
pub fn check_edges(&self) -> TractResult<()> {
    for node_id in self.eval_order()? {
        let node = &self.nodes[node_id];
        // Every input edge must be mirrored by a successor entry upstream...
        for (ix, input) in node.inputs.iter().enumerate() {
            let prec = &self.nodes[input.node];
            if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
                bail!(
                    "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                    node.id,
                    ix,
                    prec
                )
            }
        }
        // ...and every successor entry mirrored by the downstream input.
        for (ix, output) in node.outputs.iter().enumerate() {
            for succ in &output.successors {
                if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
                    bail!(
                        "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        succ
                    )
                }
            }
        }
    }
    Ok(())
}
/// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
    crate::plan::SimplePlan::new(self)
}
/// Single precursor of `id`: the node feeding its only input, provided
/// that node feeds nothing else.
pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
    let node = &self.nodes()[id];
    if node.inputs.len() != 1 {
        return Ok(None);
    }
    let prec = &self.nodes()[node.inputs[0].node];
    if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
        return Ok(None);
    }
    Ok(Some(prec))
}
/// Walk `count` single-precursor steps up from `id`.
pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
    let mut node = self.node(id);
    for _ in 0..count {
        if let Some(next) = self.single_prec(node.id)? {
            node = next
        } else {
            return Ok(None);
        }
    }
    Ok(Some(node))
}
/// Walk `count` single-successor steps down from `id`.
pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
    let mut node = self.node(id);
    for _ in 0..count {
        if let Some(next) = self.single_succ(node.id)? {
            node = next
        } else {
            return Ok(None);
        }
    }
    Ok(Some(node))
}
/// Unique successor of `id`: the node fed by its single successor inlet,
/// provided the whole node has exactly one successor and that successor
/// has a single input.
pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
    let node = &self.nodes()[id];
    if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
        return Ok(None);
    }
    // BUG FIX: the single successor may hang off any output slot; indexing
    // `outputs[0].successors[0]` panicked for a multi-output node whose
    // only successor was attached to a later slot.
    let succ = *node
        .outputs
        .iter()
        .find_map(|of| of.successors.first())
        .expect("just checked there is exactly one successor");
    let succ = &self.nodes()[succ.node];
    if succ.inputs.len() != 1 {
        return Ok(None);
    }
    Ok(Some(succ))
}
/// List the inlets fed by an outlet.
pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
    &self.nodes[outlet.node].outputs[outlet.slot].successors
}
}
// NOTE: the previous revision also bound `F: Fact + Clone + 'static` on the
// impl generics, duplicating the where clause below; the duplicate is removed.
impl<F, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Append a constant node holding tensor `v`, returning its single outlet.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let v = v.into_arc_tensor();
        // The outlet fact is derived from the tensor itself.
        let fact = F::from(v.clone());
        let name = name.into();
        self.add_node(name, crate::ops::konst::Const::new(v), tvec!(fact)).map(|id| id.into())
    }
}
impl<F, O> fmt::Display for Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    /// Tabular dump of the graph: one row per node with its first two input
    /// outlets, first two successors of output 0, op name, node name and
    /// facts; overflow lines detail extra inputs, outputs and labels.
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for i in 0..self.nodes.len() {
            // First two input outlets, blank when absent.
            let input_1 = self.nodes[i]
                .inputs
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let input_2 = self.nodes[i]
                .inputs
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            // First two successors of output slot 0.
            let output_1 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_2 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            writeln!(
                fmt,
                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
                i,
                input_1,
                input_2,
                output_1,
                output_2,
                self.nodes[i].op().name(),
                self.nodes[i].name,
                self.node_input_facts(i).unwrap(),
                self.node_output_facts(i).unwrap(),
            )?;
            // More than two inputs: list them all on a continuation line.
            if self.nodes[i].inputs.len() > 2 {
                writeln!(
                    fmt,
                    " | * inputs: {}",
                    self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
                )?;
            }
            // Extra outputs, fan-out above two, or a label differing from the
            // node name: detail each used outlet on its own line.
            if self.nodes[i].outputs.len() > 1
                || self.outlet_successors((i, 0).into()).len() > 2
                || (self.outlet_label(i.into()).is_some()
                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
            {
                for o in 0..self.nodes[i].outputs.len() {
                    if self.outlet_successors((i, o).into()).len() > 0 {
                        writeln!(
                            fmt,
                            " | * output #{}: {} {}",
                            o,
                            self.outlet_label((i, o).into()).unwrap_or(""),
                            self.outlet_successors((i, o).into())
                                .iter()
                                .map(|s| format!("{:?}", s))
                                .join(", "),
                        )?;
                    }
                }
            }
        }
        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
        Ok(())
    }
}
impl<F, O> Graph<F, O>
where
F: Fact + Clone + 'static + std::hash::Hash + for<'a> std::convert::From<&'a F>,
O: std::fmt::Display
+ std::fmt::Debug
+ Clone
+ AsRef<dyn Op>
+ AsMut<dyn Op>
+ Clone
+ 'static
+ std::hash::Hash
+ for<'a> std::convert::From<&'a O>,
Graph<F, O>: SpecialOps<F, O>,
{
/// Check the invariants of a compacted model: trivial eval order (node ids
/// already topologically sorted), no dead node besides unused sources, ids
/// matching positions, and unique node names.
#[cfg(debug_assertions)]
pub fn check_compact(&self) -> TractResult<()> {
    let order = self.eval_order()?;
    // Unused sources are tolerated: they do not appear in the eval order.
    let useless_sources = self
        .input_outlets()?
        .iter()
        .filter(|io| {
            self.outlet_successors(**io).len() == 0
                && !self.output_outlets().unwrap().contains(io)
        })
        .count();
    if order.len() + useless_sources != self.nodes.len() {
        bail!(
            "Eval order is {} long, nodes are {}, including {} unused sources",
            order.len(),
            self.nodes.len(),
            useless_sources
        );
    }
    if (0..order.len()).any(|ix| order[ix] != ix) {
        bail!("eval order is not trivial");
    }
    let mut seen = std::collections::HashSet::new();
    for (ix, n) in self.nodes.iter().enumerate() {
        if ix != n.id {
            bail!("Invalid node id: position is {}, node is {}", ix, n);
        }
        if seen.contains(&n.name) {
            eprintln!("{}", self);
            bail!("duplicate name {}", n.name);
        }
        seen.insert(&n.name);
    }
    Ok(())
}
(repository line numbers elided)
/// Compute axis invariants for a whole model: for each tracked axis, the
/// matching axis (if any) on every model input and output.
pub fn for_model(model: &TypedModel) -> TractResult<Invariants> {
    full_axis_tracking(model)?
        .into_iter()
        .map(|tracking| {
            let inputs =
                model.input_outlets()?.iter().map(|i| tracking.outlets.get(i).cloned()).collect();
            let outputs =
                model.output_outlets()?.iter().map(|i| tracking.outlets.get(i).cloned()).collect();
            Ok(AxisInfo { inputs, outputs, disposable: tracking.disposable, period: 1 })
        })
        .collect()
}
/// Scan the model for the next applicable axis-change patch suggested by
/// any node, memoizing tried (node, suggestion) pairs in `self.0` so each
/// one is attempted at most once per optimizer run.
fn next(
    &mut self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
    // Interfaces (inputs and outputs) must keep their axes observable.
    let mut interfaces = model.output_outlets()?.to_vec();
    interfaces.extend(model.input_outlets()?.iter());
    for n in model.eval_order()? {
        for suggestion in model.node(n).op.suggested_axis_changes()? {
            // `insert` returns true only for suggestions not seen before.
            if self.0.insert((n, suggestion.clone())) {
                let outlet = suggestion.0.as_outlet(model.node(n));
                let change = AxisChange { outlet, op: suggestion.1.clone() };
                if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[])
                    .with_context(|| {
                        format!("Making patch for {:?} from {}", change, model.node(n))
                    })?
                {
                    return Ok(Some(patch));
                }
            }
        }
    }
    Ok(None)
}
/// Build a plan computing `outputs`, with extra `deps` edges (pairs of node
/// ids) constraining the evaluation order.
pub fn new_for_outputs_and_deps(
    model: M,
    outputs: &[OutletId],
    deps: &[(usize, usize)],
) -> TractResult<SimplePlan<F, O, M>> {
    let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
    let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
    let order = eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?;
    // For each node, the last step at which its value is still needed.
    let mut values_needed_until_step = vec![0; model.borrow().nodes().len()];
    for (step, node) in order.iter().enumerate() {
        for i in &model.borrow().node(*node).inputs {
            values_needed_until_step[i.node] = step;
        }
    }
    // Plan outputs must survive the whole run.
    for o in outputs.iter() {
        values_needed_until_step[o.node] = order.len();
    }
    // Invert into per-step flush lists.
    let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
    for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
        if flush_at != 0 {
            flush_lists[flush_at].push(node)
        }
    }
    // Detect symbolic dimensions: if any exist, shapes are resolved at run time.
    let mut symbols: std::collections::HashSet<Symbol> = Default::default();
    for node in &model.borrow().nodes {
        for output in &node.outputs {
            if let Ok(fact) = output.fact.to_typed_fact() {
                symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
            }
        }
    }
    Ok(SimplePlan {
        model,
        order,
        flush_lists,
        outputs: outputs.to_vec(),
        has_unresolved_symbols: !symbols.is_empty(),
        _casper: PhantomData,
    })
}
/// One-shot convenience: build a fresh state and run the plan on `inputs`.
pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
    let mut state = SimpleState::new(self)?;
    state.run(inputs)
}
/// Borrow the underlying model.
pub fn model(&self) -> &Graph<F, O> {
    self.model.borrow()
}
}
/// Mutable execution state attached to a `SimplePlan`: per-node op states,
/// session-wide state, and the cached node output values for the current run.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O, M, P>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
    M: Borrow<Graph<F, O>> + Hash,
    P: Borrow<SimplePlan<F, O, M>>,
{
    // The plan this state executes.
    plan: P,
    /// Per-node op state (None for ops without state).
    pub states: Vec<Option<Box<dyn OpState>>>,
    /// Session-wide state (e.g. resolved symbol values).
    pub session_state: SessionState,
    /// Per-node cached output values, flushed once no longer needed.
    pub values: Vec<Option<TVec<TValue>>>,
    _phantom: PhantomData<(M, F, O)>,
}
impl<F, O, M, P> SimpleState<F, O, M, P>
where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
M: Borrow<Graph<F, O>> + Hash,
P: Borrow<SimplePlan<F, O, M>> + Clone,
{
/// Build a fresh state for `plan`, instantiating each node's op state.
pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
    let values = vec![None; plan.borrow().model.borrow().nodes().len()];
    let mut session = SessionState::default();
    let model = plan.borrow().model();
    let states: Vec<Option<Box<dyn OpState>>> = model
        .nodes()
        .iter()
        .map(|n: &Node<F, O>| n.op().state(&mut session, n.id))
        .collect::<TractResult<_>>()?;
    Ok(SimpleState { plan, states, session_state: session, values, _phantom: PhantomData })
}
/// Reset wires state: drop all cached node output values.
pub fn reset_turn(&mut self) -> TractResult<()> {
    self.values.iter_mut().for_each(|s| *s = None);
    Ok(())
}
/// Reset op inner state: re-instantiate every node's op state.
pub fn reset_op_states(&mut self) -> TractResult<()> {
    let &mut SimpleState { ref plan, ref mut session_state, ref mut states, .. } = self;
    *states = plan
        .borrow()
        .model()
        .nodes()
        .iter()
        .map(|n| n.op().state(session_state, n.id))
        .collect::<TractResult<_>>()?;
    Ok(())
}
/// Run the plan on `inputs` with the default node evaluator.
pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
    self.run_plan_with_eval(inputs, self::eval)
}
/// Execute the plan with the default evaluator (inputs already set).
pub fn exec(&mut self) -> TractResult<()> {
    self.exec_plan_with_eval(self::eval)
}
/// Run the plan on `inputs`, delegating each node to `eval`; returns the
/// model outputs and resets the cached wire values afterwards.
pub fn run_plan_with_eval<Eval, E>(
    &mut self,
    inputs: TVec<TValue>,
    eval: Eval,
) -> TractResult<TVec<TValue>>
where
    Eval: for<'a, 'b, 'c> FnMut(
        &'a mut SessionState,
        Option<&'b mut (dyn OpState + 'static)>,
        &'c Node<F, O>,
        TVec<TValue>,
    ) -> Result<TVec<TValue>, E>,
    E: Into<anyhow::Error> + Send + Sync + 'static,
{
    self.set_inputs(inputs)?;
    self.exec_plan_with_eval(eval)?;
    let outputs = self.outputs()?;
    self.reset_turn()?;
    Ok(outputs)
}
/// Execute every plan step with `eval` as the per-node evaluator. Inputs
/// must have been set beforehand; computed values are kept in `self.values`
/// until their flush step.
pub fn exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
where
    Eval: for<'a, 'b, 'c> FnMut(
        &'a mut SessionState,
        Option<&'b mut (dyn OpState + 'static)>,
        &'c Node<F, O>,
        TVec<TValue>,
    ) -> Result<TVec<TValue>, E>,
    E: Into<anyhow::Error> + Send + Sync + 'static,
{
    {
        let &mut SimpleState {
            ref plan,
            ref mut session_state,
            ref mut states,
            ref mut values,
            ..
        } = self;
        let plan = plan.borrow();
        let model = plan.model().borrow();
        for (step, n) in plan.order.iter().enumerate() {
            let node = model.node(*n);
            trace!("Running step {}, node {}", step, node);
            // Gather this node's input values from its precursors.
            let mut inputs: TVec<TValue> = tvec![];
            for i in &node.inputs {
                trace!(" use input {:?}", i);
                let prec_node = model.node(i.node);
                let prec = values[i.node].as_ref().ok_or_else(|| {
                    format_err!("Computing {}, precursor {} not done:", node, prec_node)
                })?;
                inputs.push(prec[i.slot].clone())
            }
            // Drop values nobody needs past this step.
            for flush in &plan.flush_lists[step] {
                trace!(" Ran {} can now flush {}", node, model.node(*flush));
                values[*flush] = None;
            }
            // Debug builds only: check inputs against declared facts.
            if cfg!(debug_assertions) {
                let facts = model.node_input_facts(node.id)?;
                if facts.len() != inputs.len() {
                    bail!(
                        "Evaluating {}: expected {} inputs, got {}",
                        node,
                        facts.len(),
                        inputs.len()
                    );
                }
                for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
                    if !f.matches(v, Some(&session_state.resolved_symbols))? {
                        bail!(
                            "Evaluating {}: input {:?}, expected {:?}, got {:?}",
                            node,
                            ix,
                            f,
                            v
                        );
                    }
                }
            }
            let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)
                .map_err(|e| e.into())?;
            // Learn symbol values by unifying abstract dims with concrete shapes.
            if plan.has_unresolved_symbols {
                for (o, v) in node.outputs.iter().zip(vs.iter()) {
                    if let Ok(f) = o.fact.to_typed_fact() {
                        for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
                            Self::resolve(
                                &mut session_state.resolved_symbols,
                                &dim_abstract,
                                *dim_concrete as i64,
                            );
                        }
                    }
                }
            }
            // Debug builds only: check outputs against declared facts.
            if cfg!(debug_assertions) {
                let facts = model.node_output_facts(node.id)?;
                if facts.len() != vs.len() {
                    bail!(
                        "Evaluating {}: expected {} outputs, got {}",
                        node,
                        facts.len(),
                        vs.len()
                    );
                }
                for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
                    // Unused outputs are not checked.
                    if node.outputs[ix].successors.len() == 0 {
                        continue;
                    }
                    if !f.matches(v, Some(&session_state.resolved_symbols))? {
                        bail!(
                            "Evaluating {}: output {:?}, expected {:?}, got {:?}",
                            node,
                            ix,
                            f,
                            v
                        );
                    }
                }
            }
            values[node.id] = Some(vs);
        }
    }
    Ok(())
}
/// Set all model inputs at once; errors out on an arity mismatch.
pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
ensure!(
inputs.len() == self.model().inputs.len(),
"Wrong number of inputs for model. Expected {} got {}",
self.model().inputs.len(),
inputs.len()
);
for (ix, t) in inputs.into_iter().enumerate() {
self.set_input(ix, t)?
}
Ok(())
}
/// Bind a concrete dimension value to the symbol(s) appearing in `expected`.
/// NOTE(review): the `MulInt` arm uses integer division, so it appears to assume
/// `provided` is an exact multiple of `x` — confirm with callers.
fn resolve(symbols: &mut SymbolValues, expected: &TDim, provided: i64) {
match expected {
TDim::Sym(s) => symbols[s] = Some(provided),
TDim::MulInt(x, expr) => Self::resolve(symbols, expr, provided / *x),
_ => (),
}
}
/// Set one model input, resolving symbolic dims from the concrete shape and
/// checking the value against the declared input fact.
pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
let outlet: OutletId = *self
.model()
.input_outlets()?
.get(input)
.ok_or_else(|| format_err!("Invalid input id for model ({}).", input))?;
let SimpleState { plan, session_state, .. } = self;
let plan = (*plan).borrow();
let model = plan.model.borrow();
// Bind any symbolic dims of the input fact from the tensor's actual shape.
if let Ok(fact) = model.outlet_fact(outlet)?.to_typed_fact() {
for (expected, provided) in fact.shape.iter().zip(t.shape()) {
Self::resolve(&mut session_state.resolved_symbols, &expected, *provided as i64)
}
}
// Re-fetch the fact and validate dtype/shape with the now-resolved symbols.
let fact = self.plan.borrow().model().outlet_fact(outlet)?;
ensure!(
fact.matches(&t, Some(&self.session_state.resolved_symbols))
.with_context(|| format!("Setting input {}", input))?,
"Input at index {} has incorrect dtype or shape (got shape {:?} and dtype {:?}, expected to match fact {:?})",
input,
t.shape(),
t.datum_type(),
fact
);
self.session_state.inputs.insert(outlet.node, t);
Ok(())
}26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
/// Translate `source` into a new graph, node by node in eval order, returning
/// the new graph plus the old-outlet -> new-outlet mapping.
fn translate_model_with_mappings(
&self,
source: &Graph<TI1, O1>,
) -> TractResult<(Graph<TI2, O2>, HashMap<OutletId, OutletId>)> {
let mut target = Graph::default();
let mut mapping = HashMap::new();
for old_id in source.eval_order()? {
let node = source.node(old_id);
trace!("Translating {} {:?}", node, self);
let outlets = self
.translate_node(source, node, &mut target, &mapping)
.with_context(|| format!("Translating node {} {:?}", node, self))?;
// Record the outlet mapping and carry outlet labels over.
for (ix, outlet) in outlets.into_iter().enumerate() {
mapping.insert(OutletId::new(node.id, ix), outlet);
if let Some(label) = source.outlet_label(OutletId::new(node.id, ix)) {
target.set_outlet_label(outlet, label.to_string())?;
}
}
}
// do not drop inputs, even if they are useless, to maintain interface
for i in source.input_outlets()? {
if !mapping.contains_key(i) {
let node = source.node(i.node);
trace!("Translate useless source {}", node);
let outlets = self
.translate_node(source, node, &mut target, &mapping)
.with_context(|| format!("Translating input {} {:?}", node, self))?;
mapping.insert(*i, outlets[0]);
}
}
// maintaining order of i/o interface
target.inputs = source.input_outlets()?.iter().map(|i| mapping[i]).collect();
target.outputs = source.output_outlets()?.iter().map(|o| mapping[o]).collect();
target.symbol_table = source.symbol_table.clone();
target.properties = source.properties.clone();
Ok((target, mapping))
}sourcepub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()>
pub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()>
Change model inputs.
Examples found in repository:
More examples:
(rustdoc source-view line-number gutter omitted)
/// Scan declutter pass: if an inner-body input is unused (no successors, not an
/// output), drop it from the body, the input mapping, and the outer node inputs,
/// and emit a patch replacing the scan op.
fn declutter_discard_unused_input_mapping(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (inner_input_id, input) in self.body.input_outlets()?.iter().enumerate() {
let source_node = self.body.node(input.node);
// Unused means: nothing consumes it inside the body and it is not a body output.
if source_node.outputs[0].successors.len() == 0
&& !self.body.output_outlets()?.contains(input)
{
let mut new_inputs = node.inputs.clone();
// Which outer slot (if any) feeds this inner input?
let slot = match &self.input_mapping[inner_input_id] {
InputMapping::Full { slot } => Some(*slot),
InputMapping::Scan(info) => Some(info.slot),
InputMapping::State { initializer } => match initializer {
StateInitializer::FromInput(n) => Some(*n),
_ => None,
},
};
let mut new_mappings: Vec<_> = self.input_mapping.clone();
new_mappings.remove(inner_input_id);
if let Some(slot) = slot {
// Shift remaining mappings to account for the removed outer slot.
new_mappings = Self::remove_outer_input_from_mappings(&new_mappings, slot);
}
let mut model_inputs = self.body.input_outlets()?.to_vec();
if let Some(slot) = slot {
new_inputs.remove(slot);
}
model_inputs.remove(inner_input_id);
// Rebuild the body without the dead source node.
let mut body = self.body.clone();
let mut patch = TypedModelPatch::default();
patch.obliterate(source_node.id)?;
patch.apply(&mut body)?;
body.set_input_outlets(&model_inputs)?;
body.declutter()?;
let op = Self {
body,
skip: self.skip,
seq_length_input_slot: self.seq_length_input_slot,
input_mapping: new_mappings,
decluttered: true,
output_mapping: self.output_mapping.clone(),
};
// Only one input is removed per invocation; the optimizer will re-run the pass.
return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
}
}
Ok(None)
}251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
/// Apply this patch to `target`: copy the patch's nodes in, rewire shunted
/// outlets, obliterate replaced nodes, and restore the i/o interface.
pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
let prior_target_inputs = target.input_outlets()?.len();
let prior_target_outputs = target.output_outlets()?.len();
let ModelPatch {
model: patch,
incoming: mut mapping,
shunt_outlet_by,
obliterate,
inputs: replaced_inputs,
..
} = self;
let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
let mut model_input_outlets = target.input_outlets()?.to_vec();
for node in patch.nodes {
if <Graph<F, O>>::is_source(&node.op)
&& mapping.contains_key(&OutletId::new(node.id, 0))
{
// this is a tap
continue;
}
let Node { id: patch_node_id, name, inputs, op, outputs } = node;
let n_outputs = outputs.len();
// Deduplication scan: reuse an existing identical node when possible.
for dup in 0..target.nodes.len() {
if target.node(dup).op().same_as(op.as_ref())
&& inputs.len() == target.node(dup).inputs.len()
&& inputs
.iter()
.zip(target.node(dup).inputs.iter())
.all(|(patch_input, d)| mapping[patch_input] == *d)
{
for ix in 0..n_outputs {
mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
}
// NOTE(review): this `continue` only advances the inner dedup scan;
// the node is still added below even when a duplicate was found
// (its mapping entries get overwritten). Confirm whether skipping the
// outer loop iteration was intended here.
continue;
}
}
let facts = outputs.into_iter().map(|of| of.fact).collect();
let added_node_id = target.add_node(name, op, facts)?;
for ix in 0..n_outputs {
mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
}
all_inputs.insert(added_node_id, inputs);
if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
// this is actually an input replacement
model_input_outlets.iter_mut().for_each(|oo| {
if oo.node == replaced_inputs[&patch_node_id] {
oo.node = added_node_id;
}
});
}
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
// Rewire every successor (and model output) of a shunted outlet to its replacement.
for (outlet, by) in shunt_outlet_by {
let replace_by = mapping[&by];
let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
for succ in succs {
target.add_edge(replace_by, succ)?;
}
for o in target.outputs.iter_mut() {
if *o == outlet {
*o = replace_by;
}
}
if let Some(label) = target.outlet_labels.remove(&outlet) {
target.set_outlet_label(replace_by, label)?;
}
}
if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
bail!("Duplicate usage of node as output");
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
// Connect the freshly-added nodes' inputs, translated through the mapping.
for (node, inputs) in all_inputs {
for (ix, input) in inputs.into_iter().enumerate() {
target.add_edge(mapping[&input], InletId::new(node, ix))?;
}
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
// Replaced nodes are not removed but turned into dummies (ids stay stable).
for node in obliterate {
target.node_mut(node).op = target.create_dummy();
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
target.set_input_outlets(&model_input_outlets)?;
Ok(())
}sourcepub fn with_input_outlets(self, inputs: &[OutletId]) -> TractResult<Self>
pub fn with_input_outlets(self, inputs: &[OutletId]) -> TractResult<Self>
Change model inputs and return self.
sourcepub fn set_input_names(
&mut self,
inputs: impl IntoIterator<Item = impl AsRef<str>>
) -> TractResult<()>
pub fn set_input_names(
&mut self,
inputs: impl IntoIterator<Item = impl AsRef<str>>
) -> TractResult<()>
Set model inputs by the node name.
sourcepub fn with_input_names(
self,
inputs: impl IntoIterator<Item = impl AsRef<str>>
) -> TractResult<Self>
pub fn with_input_names(
self,
inputs: impl IntoIterator<Item = impl AsRef<str>>
) -> TractResult<Self>
Set model inputs by the node name and return self.
sourcepub fn input_fact(&self, ix: usize) -> TractResult<&F>
pub fn input_fact(&self, ix: usize) -> TractResult<&F>
Get the ix-th input tensor type information.
sourcepub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F>
pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F>
Get the ix-th input tensor type information, mutably.
sourcepub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()>
pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()>
Set the ix-th input tensor type information.
sourcepub fn with_input_fact(self, input: usize, fact: F) -> TractResult<Self>
pub fn with_input_fact(self, input: usize, fact: F) -> TractResult<Self>
Set the ix-th input tensor type information and return self.
sourcepub fn output_outlets(&self) -> TractResult<&[OutletId]>
pub fn output_outlets(&self) -> TractResult<&[OutletId]>
Get model outputs.
Examples found in repository:
(rustdoc source-view line-number gutter omitted)
/// Get the `ix`-th output tensor type information.
pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
let output = self.output_outlets()?[ix];
self.outlet_fact(output)
}
/// Get the `ix`-th output tensor type information, mutably.
pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
let output = self.output_outlets()?[ix];
self.outlet_fact_mut(output)
}
/// Set the `ix`-th output tensor type information.
/// Panics if `output` is out of range of `self.outputs`.
pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
let outlet = self.outputs[output];
self.set_outlet_fact(outlet, fact)
}
/// Set the `ix`-th output tensor type information and return `self` (builder style).
pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
self.set_output_fact(output, fact)?;
Ok(self)
}
// nodes and their facts
/// Iterate over all node names, in node-id order.
pub fn node_names(&self) -> impl Iterator<Item = &str> {
self.nodes.iter().map(|node| node.name.as_str())
}
/// Find a node id by its exact name, or error out if absent.
pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
self.nodes
.iter()
.find(|n| n.name == name)
.map(|n| n.id)
.with_context(|| format!("No node found for name: \"{}\"", name))
}
/// Find a node by its name.
pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
let id: usize = self.node_id_by_name(name.as_ref())?;
Ok(&self.nodes[id])
}
/// Borrow mutably a node by its name.
pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
let id: usize = self.node_id_by_name(name.as_ref())?;
Ok(&mut self.nodes[id])
}
/// Rename the node with the given id. Panics if `id` is out of range.
pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
self.node_mut(id).name = name.to_string();
Ok(())
}
/// Find a node by its id.
pub fn node(&self, id: usize) -> &Node<F, O> {
&self.nodes[id]
}
/// Find a node by its id, mutably.
pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
&mut self.nodes[id]
}
/// Access the nodes table.
pub fn nodes(&self) -> &[Node<F, O>] {
&self.nodes
}
/// Access the nodes table, mutably.
pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
&mut self.nodes
}
/// Get input and output tensor information for a node.
pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
}
/// Get input tensor information for a node.
pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
}
/// Get output tensor information for a node.
pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
}
// outlets
/// Get tensor information for a single outlet.
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
let outlets = &self.nodes[outlet.node].outputs;
outlets
.get(outlet.slot)
.map(|o| &o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get tensor information for a single outlet, mutably.
pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
let outlets = &mut self.nodes[outlet.node].outputs;
outlets
.get_mut(outlet.slot)
.map(|o| &mut o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get multiple mutable tensor information for outlets.
pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
// All requested outlets must be pairwise distinct for the aliasing
// argument below to hold.
assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
// SAFETY: the assert above guarantees the outlets are pairwise distinct,
// so each `&mut F` produced by the pointer cast refers to a different
// fact and no two mutable references alias.
unsafe {
outlets
.iter()
.map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
.collect()
}
}
/// Set tensor information for a single outlet.
///
/// # Errors
/// Fails if `outlet.slot` does not exist on the node.
/// Panics if `outlet.node` is out of range, consistent with the sibling accessors.
pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
let outlets = &mut self.nodes[outlet.node].outputs;
if outlets.len() <= outlet.slot {
// Fixed typo in the error message ("refererence" -> "reference") so it
// matches the wording used by outlet_fact/outlet_fact_mut.
bail!("Invalid outlet reference: {:?}", outlet)
}
outlets[outlet.slot].fact = fact;
Ok(())
}
/// Set tensor information for a single outlet and return `self` (builder style).
pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
self.set_outlet_fact(outlet, fact)?;
Ok(self)
}
// outlet labels
/// Get label for an outlet, if one was recorded.
pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
self.outlet_labels.get(&outlet).map(String::as_str)
}
/// Set label for an outlet (overwrites any previous label).
pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
self.outlet_labels.insert(outlet, label);
Ok(())
}
/// Set label for an outlet and return `self` (builder style).
pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
self.set_outlet_label(outlet, label)?;
Ok(self)
}
/// Find outlet by label (linear scan over the label map).
pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
}
// misc
/// Computes an evaluation order for the graph inputs and outputs
pub fn eval_order(&self) -> TractResult<Vec<usize>> {
eval_order(self)
}
// No-op variant: the edge check is only compiled in with debug_assertions
// plus the "paranoid_assertions" feature.
#[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
#[inline]
pub fn check_edges(&self) -> TractResult<()> {
Ok(())
}
/// Performs a sanity check on network connections.
#[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
#[inline]
pub fn check_edges(&self) -> TractResult<()> {
for node_id in self.eval_order()? {
let node = &self.nodes[node_id];
// Every input edge must be mirrored by a successor entry on the precursor.
for (ix, input) in node.inputs.iter().enumerate() {
let prec = &self.nodes[input.node];
if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
bail!(
"Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
node.id,
ix,
prec
)
}
}
// Every successor entry must be mirrored by the successor's input edge.
for (ix, output) in node.outputs.iter().enumerate() {
for succ in &output.successors {
if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
bail!(
"Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
node.id,
ix,
succ
)
}
}
}
}
Ok(())
}
/// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
crate::plan::SimplePlan::new(self)
}
/// Return the unique precursor of node `id`, if the node has exactly one input
/// and that precursor has exactly one successor overall.
pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
let node = &self.nodes()[id];
if node.inputs.len() != 1 {
return Ok(None);
}
let prec = &self.nodes()[node.inputs[0].node];
if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
return Ok(None);
}
Ok(Some(prec))
}
/// Walk `count` steps up the single-precursor chain from node `id`.
pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
let mut node = self.node(id);
for _ in 0..count {
if let Some(next) = self.single_prec(node.id)? {
node = next
} else {
return Ok(None);
}
}
Ok(Some(node))
}
/// Walk `count` steps down the single-successor chain from node `id`.
pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
let mut node = self.node(id);
for _ in 0..count {
if let Some(next) = self.single_succ(node.id)? {
node = next
} else {
return Ok(None);
}
}
Ok(Some(node))
}
/// Return the unique successor of node `id`, if it has exactly one successor
/// overall and that successor has exactly one input.
/// NOTE(review): after summing successors across all outputs, the code indexes
/// `outputs[0].successors[0]` — this looks like it assumes the single successor
/// hangs off output slot 0; confirm for multi-output nodes.
pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
let node = &self.nodes()[id];
if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
return Ok(None);
}
let succ = node.outputs[0].successors[0];
let succ = &self.nodes()[succ.node];
if succ.inputs.len() != 1 {
return Ok(None);
}
Ok(Some(succ))
}
/// List the inlets fed by a given outlet.
pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
&self.nodes[outlet.node].outputs[outlet.slot].successors
}
}
// Extra constructor available when facts can be derived from constant tensors
// and ops can wrap a Const.
impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
O: fmt::Debug
+ fmt::Display
+ From<crate::ops::konst::Const>
+ AsRef<dyn Op>
+ AsMut<dyn Op>
+ Clone
+ Hash
+ 'static,
{
/// Add a constant node holding tensor `v`, returning its outlet.
pub fn add_const(
&mut self,
name: impl Into<String>,
v: impl IntoArcTensor,
) -> TractResult<OutletId> {
let v = v.into_arc_tensor();
// The fact is derived directly from the constant value.
let fact = F::from(v.clone());
let name = name.into();
self.add_node(name, crate::ops::konst::Const::new(v), tvec!(fact)).map(|id| id.into())
}
}
// Human-readable tabular dump of the graph: one row per node with up to two
// inputs/successors inline, plus overflow lines for wider nodes.
impl<F, O> fmt::Display for Graph<F, O>
where
F: Fact + Hash + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
for i in 0..self.nodes.len() {
// First two inputs, rendered empty when absent.
let input_1 = self.nodes[i]
.inputs
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let input_2 = self.nodes[i]
.inputs
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
// First two successors of output slot 0.
let output_1 = self
.outlet_successors(OutletId::new(i, 0))
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_2 = self
.outlet_successors(OutletId::new(i, 0))
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
writeln!(
fmt,
"{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
i,
input_1,
input_2,
output_1,
output_2,
self.nodes[i].op().name(),
self.nodes[i].name,
self.node_input_facts(i).unwrap(),
self.node_output_facts(i).unwrap(),
)?;
// Overflow line for nodes with more than two inputs.
if self.nodes[i].inputs.len() > 2 {
writeln!(
fmt,
" | * inputs: {}",
self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
)?;
}
// Per-output detail lines for multi-output nodes, crowded outlets,
// or outlets labeled differently from the node name.
if self.nodes[i].outputs.len() > 1
|| self.outlet_successors((i, 0).into()).len() > 2
|| (self.outlet_label(i.into()).is_some()
&& self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
{
for o in 0..self.nodes[i].outputs.len() {
if self.outlet_successors((i, o).into()).len() > 0 {
writeln!(
fmt,
" | * output #{}: {} {}",
o,
self.outlet_label((i, o).into()).unwrap_or(""),
self.outlet_successors((i, o).into())
.iter()
.map(|s| format!("{:?}", s))
.join(", "),
)?;
}
}
}
}
writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
Ok(())
}
}
impl<F, O> Graph<F, O>
where
F: Fact + Clone + 'static + std::hash::Hash + for<'a> std::convert::From<&'a F>,
O: std::fmt::Display
+ std::fmt::Debug
+ Clone
+ AsRef<dyn Op>
+ AsMut<dyn Op>
+ Clone
+ 'static
+ std::hash::Hash
+ for<'a> std::convert::From<&'a O>,
Graph<F, O>: SpecialOps<F, O>,
{
// Debug-only invariant check: a "compact" graph has trivial eval order
// (node ids match positions), no dangling ids, and unique node names.
#[cfg(debug_assertions)]
pub fn check_compact(&self) -> TractResult<()> {
let order = self.eval_order()?;
// Inputs kept only for interface purposes are allowed to be outside the order.
let useless_sources = self
.input_outlets()?
.iter()
.filter(|io| {
self.outlet_successors(**io).len() == 0
&& !self.output_outlets().unwrap().contains(io)
})
.count();
if order.len() + useless_sources != self.nodes.len() {
bail!(
"Eval order is {} long, nodes are {}, including {} unused sources",
order.len(),
self.nodes.len(),
useless_sources
);
}
if (0..order.len()).any(|ix| order[ix] != ix) {
bail!("eval order is not trivial");
}
let mut seen = std::collections::HashSet::new();
for (ix, n) in self.nodes.iter().enumerate() {
if ix != n.id {
bail!("Invalid node id: position is {}, node is {}", ix, n);
}
if seen.contains(&n.name) {
eprintln!("{}", self);
bail!("duplicate name {}", n.name);
}
seen.insert(&n.name);
}
Ok(())
}More examples
(rustdoc source-view line-number gutter omitted)
/// Build a plan computing all of the model's declared outputs.
pub fn new(model: M) -> TractResult<SimplePlan<F, O, M>> {
let outputs = model.borrow().output_outlets()?.to_vec();
Self::new_for_outputs(model, &outputs)
}
/// This constructor returns a plan that will compute the specified output.
pub fn new_for_output(model: M, output: OutletId) -> TractResult<SimplePlan<F, O, M>> {
Self::new_for_outputs_and_deps(model, &[output], &[])
}
/// This constructor returns a plan that will compute all specified outputs in one pass.
pub fn new_for_outputs(model: M, outputs: &[OutletId]) -> TractResult<SimplePlan<F, O, M>> {
Self::new_for_outputs_and_deps(model, outputs, &[])
}
/// Build a plan for `outputs`, with extra ordering constraints `deps`,
/// precomputing the node order and per-step flush lists.
pub fn new_for_outputs_and_deps(
model: M,
outputs: &[OutletId],
deps: &[(usize, usize)],
) -> TractResult<SimplePlan<F, O, M>> {
let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
let order = eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?;
// Last step at which each node's value is still consumed.
let mut values_needed_until_step = vec![0; model.borrow().nodes().len()];
for (step, node) in order.iter().enumerate() {
for i in &model.borrow().node(*node).inputs {
values_needed_until_step[i.node] = step;
}
}
// Plan outputs must survive until after the last step.
for o in outputs.iter() {
values_needed_until_step[o.node] = order.len();
}
// Invert into per-step flush lists: values freed when no longer needed.
let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
if flush_at != 0 {
flush_lists[flush_at].push(node)
}
}
// Collect all shape symbols so the runtime knows whether resolution is needed.
let mut symbols: std::collections::HashSet<Symbol> = Default::default();
for node in &model.borrow().nodes {
for output in &node.outputs {
if let Ok(fact) = output.fact.to_typed_fact() {
symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
}
}
}
Ok(SimplePlan {
model,
order,
flush_lists,
outputs: outputs.to_vec(),
has_unresolved_symbols: !symbols.is_empty(),
_casper: PhantomData,
})
}
/// One-shot run: build a transient state, run it on `inputs`.
pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let mut state = SimpleState::new(self)?;
state.run(inputs)
}
/// Access the planned model.
pub fn model(&self) -> &Graph<F, O> {
self.model.borrow()
}
}
/// Mutable execution state attached to a `SimplePlan`: per-node op states,
/// session-wide state, and cached node output values for the current turn.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O, M, P>
where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
M: Borrow<Graph<F, O>> + Hash,
P: Borrow<SimplePlan<F, O, M>>,
{
// The plan being executed (owned, borrowed or shared, per `P`).
plan: P,
/// Per-node op state, `None` for stateless ops.
pub states: Vec<Option<Box<dyn OpState>>>,
/// Session-wide state: set inputs, resolved symbols, etc.
pub session_state: SessionState,
/// Cached outputs per node for the current turn; `None` until computed or after flush.
pub values: Vec<Option<TVec<TValue>>>,
// Ties the unused type parameters to the struct for variance purposes.
_phantom: PhantomData<(M, F, O)>,
}
impl<F, O, M, P> SimpleState<F, O, M, P>
where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
M: Borrow<Graph<F, O>> + Hash,
P: Borrow<SimplePlan<F, O, M>> + Clone,
{
pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
let values = vec![None; plan.borrow().model.borrow().nodes().len()];
let mut session = SessionState::default();
let model = plan.borrow().model();
let states: Vec<Option<Box<dyn OpState>>> = model
.nodes()
.iter()
.map(|n: &Node<F, O>| n.op().state(&mut session, n.id))
.collect::<TractResult<_>>()?;
Ok(SimpleState { plan, states, session_state: session, values, _phantom: PhantomData })
}
/// Reset wires state.
pub fn reset_turn(&mut self) -> TractResult<()> {
self.values.iter_mut().for_each(|s| *s = None);
Ok(())
}
/// Reset op inner state.
pub fn reset_op_states(&mut self) -> TractResult<()> {
let &mut SimpleState { ref plan, ref mut session_state, ref mut states, .. } = self;
*states = plan
.borrow()
.model()
.nodes()
.iter()
.map(|n| n.op().state(session_state, n.id))
.collect::<TractResult<_>>()?;
Ok(())
}
pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
self.run_plan_with_eval(inputs, self::eval)
}
pub fn exec(&mut self) -> TractResult<()> {
self.exec_plan_with_eval(self::eval)
}
pub fn run_plan_with_eval<Eval, E>(
&mut self,
inputs: TVec<TValue>,
eval: Eval,
) -> TractResult<TVec<TValue>>
where
Eval: for<'a, 'b, 'c> FnMut(
&'a mut SessionState,
Option<&'b mut (dyn OpState + 'static)>,
&'c Node<F, O>,
TVec<TValue>,
) -> Result<TVec<TValue>, E>,
E: Into<anyhow::Error> + Send + Sync + 'static,
{
self.set_inputs(inputs)?;
self.exec_plan_with_eval(eval)?;
let outputs = self.outputs()?;
self.reset_turn()?;
Ok(outputs)
}
pub fn exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
where
Eval: for<'a, 'b, 'c> FnMut(
&'a mut SessionState,
Option<&'b mut (dyn OpState + 'static)>,
&'c Node<F, O>,
TVec<TValue>,
) -> Result<TVec<TValue>, E>,
E: Into<anyhow::Error> + Send + Sync + 'static,
{
{
let &mut SimpleState {
ref plan,
ref mut session_state,
ref mut states,
ref mut values,
..
} = self;
let plan = plan.borrow();
let model = plan.model().borrow();
for (step, n) in plan.order.iter().enumerate() {
let node = model.node(*n);
trace!("Running step {}, node {}", step, node);
let mut inputs: TVec<TValue> = tvec![];
for i in &node.inputs {
trace!(" use input {:?}", i);
let prec_node = model.node(i.node);
let prec = values[i.node].as_ref().ok_or_else(|| {
format_err!("Computing {}, precursor {} not done:", node, prec_node)
})?;
inputs.push(prec[i.slot].clone())
}
for flush in &plan.flush_lists[step] {
trace!(" Ran {} can now flush {}", node, model.node(*flush));
values[*flush] = None;
}
if cfg!(debug_assertions) {
let facts = model.node_input_facts(node.id)?;
if facts.len() != inputs.len() {
bail!(
"Evaluating {}: expected {} inputs, got {}",
node,
facts.len(),
inputs.len()
);
}
for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
if !f.matches(v, Some(&session_state.resolved_symbols))? {
bail!(
"Evaluating {}: input {:?}, expected {:?}, got {:?}",
node,
ix,
f,
v
);
}
}
}
let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)
.map_err(|e| e.into())?;
if plan.has_unresolved_symbols {
for (o, v) in node.outputs.iter().zip(vs.iter()) {
if let Ok(f) = o.fact.to_typed_fact() {
for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
Self::resolve(
&mut session_state.resolved_symbols,
&dim_abstract,
*dim_concrete as i64,
);
}
}
}
}
if cfg!(debug_assertions) {
let facts = model.node_output_facts(node.id)?;
if facts.len() != vs.len() {
bail!(
"Evaluating {}: expected {} outputs, got {}",
node,
facts.len(),
vs.len()
);
}
for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
if node.outputs[ix].successors.len() == 0 {
continue;
}
if !f.matches(v, Some(&session_state.resolved_symbols))? {
bail!(
"Evaluating {}: output {:?}, expected {:?}, got {:?}",
node,
ix,
f,
v
);
}
}
}
values[node.id] = Some(vs);
}
}
Ok(())
}
pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
ensure!(
inputs.len() == self.model().inputs.len(),
"Wrong number of inputs for model. Expected {} got {}",
self.model().inputs.len(),
inputs.len()
);
for (ix, t) in inputs.into_iter().enumerate() {
self.set_input(ix, t)?
}
Ok(())
}
/// Tries to bind a symbol to a concrete value, given that the symbolic
/// dimension `expected` evaluated to the integer `provided` at runtime.
///
/// Only two expression shapes are handled: a bare symbol (bound directly)
/// and an integer multiple of a sub-expression (recursed into after
/// dividing `provided` by the factor). Anything else is silently ignored.
fn resolve(symbols: &mut SymbolValues, expected: &TDim, provided: i64) {
    match expected {
        TDim::Sym(s) => symbols[s] = Some(provided),
        // NOTE(review): integer division — assumes `provided` is an exact
        // multiple of *x; any remainder is silently dropped. Confirm callers
        // only pass dims that are true multiples.
        TDim::MulInt(x, expr) => Self::resolve(symbols, expr, provided / *x),
        _ => (),
    }
}
/// Bind one input tensor (by input index) for the next run.
///
/// Resolves symbolic dimensions of the input fact against the concrete
/// shape of `t`, then checks that `t` matches the expected fact before
/// storing it in the session state.
pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
    // Translate the input index into the graph outlet that feeds it.
    let outlet: OutletId = *self
        .model()
        .input_outlets()?
        .get(input)
        .ok_or_else(|| format_err!("Invalid input id for model ({}).", input))?;
    // Destructure so plan and session_state can be borrowed independently.
    let SimpleState { plan, session_state, .. } = self;
    let plan = (*plan).borrow();
    let model = plan.model.borrow();
    // Learn symbol values (e.g. a batch-size symbol) from the concrete shape.
    if let Ok(fact) = model.outlet_fact(outlet)?.to_typed_fact() {
        for (expected, provided) in fact.shape.iter().zip(t.shape()) {
            Self::resolve(&mut session_state.resolved_symbols, &expected, *provided as i64)
        }
    }
    let fact = self.plan.borrow().model().outlet_fact(outlet)?;
    // Reject tensors whose dtype or (symbol-resolved) shape disagree with the model.
    ensure!(
        fact.matches(&t, Some(&self.session_state.resolved_symbols))
            .with_context(|| format!("Setting input {}", input))?,
        "Input at index {} has incorrect dtype or shape (got shape {:?} and dtype {:?}, expected to match fact {:?})",
        input,
        t.shape(),
        t.datum_type(),
        fact
    );
    self.session_state.inputs.insert(outlet.node, t);
    Ok(())
}
pub fn output(&self, id: usize) -> TractResult<&TValue> {
let outlet = self.model().output_outlets()?.get(id).with_context(|| {
format!(
"Required output {}, only have {}",
id,
self.model().output_outlets().unwrap().len()
)
})?;
let value: &TValue = self
.values
.get(outlet.node)
.context("node id for output beyond node values array")?
.as_ref()
.context("node is not an output")?
.get(outlet.slot)
.context("slot id too high")?;
Ok(value)
}341 342 343 344 345 346 347 348 349 350 351 352
pub fn for_model(model: &TypedModel) -> TractResult<Invariants> {
full_axis_tracking(model)?
.into_iter()
.map(|tracking| {
let inputs =
model.input_outlets()?.iter().map(|i| tracking.outlets.get(i).cloned()).collect();
let outputs =
model.output_outlets()?.iter().map(|i| tracking.outlets.get(i).cloned()).collect();
Ok(AxisInfo { inputs, outputs, disposable: tracking.disposable, period: 1 })
})
.collect()
}47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if self.with_index_outputs.is_some()
&& node.outputs[1].successors.len() == 0
&& !model.output_outlets()?.contains(&OutletId::new(node.id, 1))
{
let op = Self { with_index_outputs: None, ..self.clone() };
let mut patch = TypedModelPatch::default();
let mut wire = patch.tap_model(model, node.inputs[0])?;
wire = patch.wire_node(&node.name, op, &[wire])?[0];
patch.shunt_outside(model, node.id.into(), wire)?;
return Ok(Some(patch));
}
Ok(None)
}24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
fn next(
&mut self,
_session: &mut OptimizerSession,
model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
let mut interfaces = model.output_outlets()?.to_vec();
interfaces.extend(model.input_outlets()?.iter());
for n in model.eval_order()? {
for suggestion in model.node(n).op.suggested_axis_changes()? {
if self.0.insert((n, suggestion.clone())) {
let outlet = suggestion.0.as_outlet(model.node(n));
let change = AxisChange { outlet, op: suggestion.1.clone() };
if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[])
.with_context(|| {
format!("Making patch for {:?} from {}", change, model.node(n))
})?
{
return Ok(Some(patch));
}
}
}
}
Ok(None)
}sourcepub fn auto_outputs(&mut self) -> TractResult<()>
pub fn auto_outputs(&mut self) -> TractResult<()>
Guess outputs from the topology: node or nodes with no successors.
pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()>
pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()>
Change model outputs.
Examples found in repository?
More examples
779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let mut model = TypedModel::default();
let mut wires: TVec<OutletId> = inputs
.iter()
.enumerate()
.map(|(ix, v)| {
model.add_source(format!("source.{}", ix), v.datum_type().fact(v.shape()))
})
.collect::<TractResult<_>>()?;
let new_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
let wire = unsafe {
if self.q_params.is_some() {
let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
op_ref.wire_as_quant_im2col(
&mut model,
"im2col-adhoc",
inputs[0].datum_type(),
&wires,
)?
} else {
self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
}
};
model.set_output_outlets(&[wire])?;
model.into_runnable()?.run(inputs)
}262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
ensure!(
inputs[0].rank() == inputs[1].rank(),
"Rank mismatch {:?} vs {:?}",
inputs[0],
inputs[1]
);
let mut model = TypedModel::default();
let a = model.add_const("source_a", inputs[0].clone().into_arc_tensor())?;
let b = model.add_const("source_b", inputs[1].clone().into_arc_tensor())?;
let bias = model.add_const("source_bias", inputs[2].clone().into_arc_tensor())?;
let mut input_outlets = tvec![a, b, bias];
for (i, t) in inputs.iter().enumerate().skip(3) {
input_outlets
.push(model.add_const(format!("source_{}", i), t.clone().into_arc_tensor())?)
}
let mut params = self.params.as_outlet_ids(
&mut model,
"qmatmul_unary",
&input_outlets,
inputs[0].datum_type(),
inputs[1].datum_type(),
self.output_type,
)?;
let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;
let new_op = MatMul { axes: self.axes };
let result = model.wire_node("adhoc.matmul", new_op, &[a, b])?[0];
let result = wire_matmul_quant(
&mut model,
"adhoc",
a,
b,
Some(bias),
self.axes,
result,
self.output_type,
¶ms,
)?;
model.set_output_outlets(&[result])?;
model.into_runnable()?.run(tvec![])
}34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
ensure!(inputs[0].rank() == self.a.rank(), "Rank mismatch {:?} vs {:?}", inputs[0], self.a);
let mut model = TypedModel::default();
let t_a = self.a.offset_u8_as_i8();
let a = model.add_const("source_a", self.a.clone())?;
let b = model.add_const("source_b", inputs[0].clone().into_arc_tensor())?;
let bias = if let Some(bias) = self.bias.clone() {
Some(model.add_const("source_bias", bias)?)
} else {
None
};
let mut input_outlets = tvec![a];
for (i, t) in inputs.iter().enumerate().skip(1) {
input_outlets
.push(model.add_const(format!("source_{}", i), t.clone().into_arc_tensor())?)
}
let mut params = self.params.as_outlet_ids(
&mut model,
"qmatmul_unary",
&input_outlets,
self.a.datum_type(),
inputs[0].datum_type(),
self.output_type,
)?;
let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;
let new_op = MatMulUnary { a: t_a, axes: self.axes };
let result = model.wire_node("adhoc.matmul", new_op, &[b])?[0];
let result = wire_matmul_quant(
&mut model,
"adhoc",
a,
b,
bias,
self.axes,
result,
self.output_type,
¶ms,
)?;
model.set_output_outlets(&[result])?;
model.into_runnable()?.run(tvec![])
}6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
pub fn pull_downsample_over_scan(
model: &TypedModel,
scan_node: &TypedNode,
scan_op: &ops::scan::Scan,
down_node: &TypedNode,
down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
if down_op.stride < 0 {
return Ok(None);
}
// introduce downsample at end of body
let mut downsampled_body = scan_op.body.clone();
downsampled_body.check_consistency()?;
let outputs = downsampled_body.output_outlets()?.to_owned();
let downsample_outputs = outputs
.into_iter()
.enumerate()
.map(|(ix, oo)| {
Ok(downsampled_body.wire_node(
format!("{}-{}", &down_node.name, ix),
down_op.clone(),
&[oo],
)?[0])
})
.collect::<TractResult<Vec<_>>>()?;
downsampled_body.set_output_outlets(&downsample_outputs)?;
downsampled_body.declutter()?;
downsampled_body.check_consistency()?;
// check if downsample ops introduced at end have swimmed up to scan inputs during declutter
for input in downsampled_body.input_outlets()? {
let input = downsampled_body.node(input.node);
if input.outputs[0]
.successors
.iter()
.any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
{
return Ok(None);
}
}
let inputs = downsampled_body.input_outlets()?.to_vec();
for input in inputs {
let node = &mut downsampled_body.node_mut(input.node);
let fact = &mut node.outputs[0].fact;
*fact = down_op.transform_fact(fact)?;
node.op_as_mut::<crate::ops::source::TypedSource>().unwrap().fact = fact.clone();
let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
for ds in downsamples {
TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
.apply(&mut downsampled_body)?;
}
}
downsampled_body.check_consistency()?;
let inner_model = downsampled_body.into_decluttered()?;
let mut new_scan = scan_op.clone();
new_scan.body = inner_model;
for input in &mut new_scan.input_mapping {
match input {
InputMapping::State { ref mut initializer } => {
if let StateInitializer::Value(ref v) = initializer {
let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
*initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
}
}
InputMapping::Scan(info) => {
if info.chunk > 0 && info.chunk as usize % down_op.stride as usize != 0 {
return Ok(None);
}
info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
* info.chunk.signum()
}
_ => (),
}
}
for output in &mut new_scan.output_mapping {
if let Some(d) = output.full_dim_hint.as_mut() {
*d = down_op.transform_dim(d)
}
if let Some(info) = &mut output.scan {
if info.chunk as usize % down_op.stride as usize != 0 {
return Ok(None);
}
info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
* info.chunk.signum()
}
}
let mut patch = TypedModelPatch::default();
let mut inputs = tvec!();
for (ix, &i) in scan_node.inputs.iter().enumerate() {
let tap = patch.tap_model(model, i)?;
let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
inputs.push(ds);
}
let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
for ix in 0..scan_node.outputs.len() {
// FIXME need to check earlier on that all output are followed by a ds
let succ = scan_node.outputs[ix].successors[0].node;
patch.shunt_outside(model, OutletId::new(succ, 0), scan[ix])?;
}
Ok(Some(patch))
}sourcepub fn with_output_outlets(self, outputs: &[OutletId]) -> TractResult<Self>
pub fn with_output_outlets(self, outputs: &[OutletId]) -> TractResult<Self>
Change model outputs and return self.
pub fn set_output_names(
&mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>
) -> TractResult<()>
pub fn set_output_names(
&mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>
) -> TractResult<()>
Set model outputs by node names.
pub fn with_output_names(
self,
outputs: impl IntoIterator<Item = impl AsRef<str>>
) -> TractResult<Self>
pub fn with_output_names(
self,
outputs: impl IntoIterator<Item = impl AsRef<str>>
) -> TractResult<Self>
Set model outputs by node names and return self.
pub fn output_fact(&self, ix: usize) -> TractResult<&F>
pub fn output_fact(&self, ix: usize) -> TractResult<&F>
Get the ix-th output tensor type information.
Examples found in repository?
381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678
/// Export constant body outputs directly as outer constants.
///
/// When a scan body output feeding a last-value slot is a known constant,
/// replace the corresponding scan output with a plain const node outside
/// the loop.
fn declutter_pull_constant_outputs(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
    for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
        if let Some(slot) = mapping.last_value_slot {
            if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
                let inner_node = self.body.output_outlets()?[model_output_ix].node;
                let inner_node = self.body.node(inner_node);
                let mut patch = TypedModelPatch::new(format!("Extract const node {}", inner_node));
                let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
                // Redirect the outer output slot to the hoisted constant.
                patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
                return Ok(Some(patch));
            }
        }
    }
    Ok(None)
}
/// Move a body-output computation outside the scan loop.
///
/// When the node feeding a scanned body output has no other use in the
/// body, re-export its *inputs* through the scan interface instead and
/// replay the node once, outside the loop, on the accumulated outputs.
fn declutter_pull_batcheable_output(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (model_ix, mapping) in self.output_mapping.iter().enumerate() {
        if let Some(info) = mapping.scan {
            let emitter_outlet = self.body.output_outlets()?[model_ix];
            let emitter_node = self.body.node(emitter_outlet.node);
            if emitter_node.outputs[emitter_outlet.slot].successors.len() > 0
                || mapping.state
                || mapping.scan.map(|i| i.chunk > 1).unwrap_or(true)
            {
                // continue if both last_value and full values are exported
                continue;
            }
            // The scan axis must be trackable backwards through the emitter.
            let (input_facts, output_facts) = self.body.node_facts(emitter_node.id)?;
            let invariants = emitter_node.op.invariants(&input_facts, &output_facts)?;
            let Some(axis_before) = invariants.unary_track_axis_up(info.axis, false)
            else {
                continue;
            };
            // Export each emitter input through the scan interface, as a
            // last-value slot (constants) or as a scanned output.
            let mut new_body = self.body.clone();
            let mut new_output_mapping = self.output_mapping.clone();
            let mut new_scan_outputs = node.outputs.len();
            let mut outer_slots = vec![];
            for input in &emitter_node.inputs {
                if new_body.outputs.iter().all(|o| o != input) {
                    new_output_mapping.push(OutputMapping::default());
                    new_body.outputs.push(*input);
                }
                let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                let mut mapping = &mut new_output_mapping[body_output_id];
                let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                    if mapping.last_value_slot.is_none() {
                        mapping.last_value_slot = Some(new_scan_outputs);
                    }
                    new_scan_outputs += 1;
                    mapping.last_value_slot.unwrap()
                } else {
                    if mapping.scan.is_none() {
                        mapping.scan = Some(ScanInfo {
                            slot: new_scan_outputs,
                            axis: axis_before,
                            chunk: info.chunk,
                        });
                        new_scan_outputs += 1;
                    }
                    mapping.scan.unwrap().slot
                };
                outer_slots.push(outer_slot);
            }
            // Rebuild the scan, then replay the emitter once outside it.
            let mut outside_patch = TypedModelPatch::new(format!(
                "Outside patch for output extraction of {}",
                emitter_node
            ));
            let inputs = node
                .inputs
                .iter()
                .map(|&i| outside_patch.tap_model(model, i))
                .collect::<TractResult<TVec<_>>>()?;
            let new_op = Self {
                input_mapping: self.input_mapping.clone(),
                output_mapping: new_output_mapping,
                decluttered: false,
                body: new_body,
                skip: self.skip,
                seq_length_input_slot: self.seq_length_input_slot,
            };
            let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
            let output = mapping.scan.unwrap();
            let inputs =
                outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
            let wire = outside_patch.wire_node(
                &*emitter_node.name,
                emitter_node.op.clone(),
                &inputs,
            )?[0];
            outside_patch.shunt_outside(model, OutletId::new(node.id, output.slot), wire)?;
            // All other outer outputs pass through unchanged.
            for output_slot in 0..node.outputs.len() {
                if output_slot != output.slot {
                    outside_patch.shunt_outside(
                        model,
                        OutletId::new(node.id, output_slot),
                        OutletId::new(scan_outputs[0].node, output_slot),
                    )?;
                }
            }
            return Ok(Some(outside_patch));
        }
    }
    Ok(None)
}
/// Pair each state input outlet of the body with its matching state
/// output outlet.
///
/// These pairs bound the recurrence: axis-change tracking must treat each
/// (input, output) state couple consistently.
fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
    let input_state_outlets = self
        .input_mapping
        .iter()
        .zip(self.body.input_outlets()?.iter())
        .filter(|(m, _)| m.as_state().is_some())
        .map(|(_, o)| o);
    let output_state_outlets = self
        .output_mapping
        .iter()
        .zip(self.body.output_outlets()?.iter())
        .filter(|(m, _)| m.state)
        .map(|(_, o)| o);
    // Relies on state inputs and state outputs appearing in the same order.
    Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
}
fn body_exposed_outlets(&self) -> TractResult<TVec<OutletId>> {
let input_outlets = self
.input_mapping
.iter()
.zip(self.body.input_outlets()?.iter())
.filter(|(m, _)| !m.invisible())
.map(|(_, o)| o);
let output_outlets = self
.output_mapping
.iter()
.zip(self.body.output_outlets()?.iter())
.filter(|(m, _)| !m.invisible())
.map(|(_, o)| o);
Ok(input_outlets.chain(output_outlets).cloned().collect())
}
/// Apply an axis change inside the scan body and derive its consequences
/// on the scan interface.
///
/// Runs change_axes on the body; on success, rebuilds the input/output
/// mappings (axes remapped, constant state initializers transformed) and
/// reports which outer wires changed. Returns Ok(None) when the change
/// cannot be carried through an axis that disappears.
fn try_body_axes_change(
    &self,
    change: AxisChange,
    locked_interface: bool,
) -> TractResult<Option<AxisChangeConsequence>> {
    self.body.check_consistency()?;
    let interface = self.body_exposed_outlets()?;
    let (patch, body_changed_wires) = if let Some(changes) =
        crate::ops::change_axes::change_axes(
            &self.body,
            &change,
            if locked_interface { &interface } else { &[] },
            &self.body_bounds()?,
        )? {
        changes
    } else {
        return Ok(None);
    };
    let mut body = self.body.clone();
    patch.apply(&mut body)?;
    body.compact()?;
    let mut wire_changes = tvec!();
    // Propagate the change through the input mappings.
    let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
    for (ix, m) in input_mapping.iter_mut().enumerate() {
        if let Some(change) = body_changed_wires
            .iter()
            .find(|(iface, _change)| iface == &InOut::In(ix))
            .map(|pair| pair.1.clone())
        {
            if let Some(slot) = m.slot() {
                wire_changes.push((InOut::In(slot), change.clone()));
            }
            match &*m {
                InputMapping::Full { .. } => (),
                &InputMapping::Scan(info) => {
                    if let Some(axis) = change.transform_axis(info.axis) {
                        *m = InputMapping::Scan(ScanInfo { axis, ..info });
                    } else {
                        // The scanned axis vanished under the change: abort.
                        return Ok(None);
                    };
                }
                InputMapping::State { initializer } => match initializer {
                    StateInitializer::FromInput(_) => (),
                    StateInitializer::Value(ref v) => {
                        // Transform the constant initializer tensor itself.
                        let mut v = v.clone().into_tensor();
                        change.change_tensor(&mut v, false)?;
                        *m = InputMapping::State {
                            initializer: StateInitializer::Value(v.into_arc_tensor()),
                        };
                    }
                },
            };
        }
    }
    // Propagate the change through the output mappings.
    let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
    for (ix, m) in output_mapping.iter_mut().enumerate() {
        if let Some(change) = body_changed_wires
            .iter()
            .find(|(iface, _change)| iface == &InOut::Out(ix))
            .map(|pair| pair.1.clone())
        {
            if let Some(info) = m.scan.as_mut() {
                if let Some(new_axis) = change.transform_axis(info.axis) {
                    info.axis = new_axis;
                } else {
                    return Ok(None);
                }
                wire_changes.push((InOut::Out(info.slot), change.clone()));
            }
            if let Some(slot) = m.last_value_slot {
                wire_changes.push((InOut::Out(slot), change.clone()));
            }
        };
    }
    body.check_consistency()?;
    // decluttered: false forces the new op through declutter again.
    let op = Some(Box::new(Scan {
        body,
        input_mapping,
        output_mapping,
        decluttered: false,
        ..self.clone()
    }) as _);
    Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
}
}
impl Op for Scan {
    /// Human-readable op name.
    fn name(&self) -> Cow<str> {
        "Scan".into()
    }

    /// One line per mapped body input and output, for dump/debug display.
    fn info(&self) -> TractResult<Vec<String>> {
        let input_lines = self
            .input_mapping
            .iter()
            .enumerate()
            .map(|(ix, im)| format!("Model input #{}: {:?}", ix, im));
        let output_lines = self
            .output_mapping
            .iter()
            .enumerate()
            .map(|(ix, om)| format!("Model output #{}: {:?}", ix, om));
        Ok(input_lines.chain(output_lines).collect())
    }

    /// Scan tolerates rounding differences when validated against references.
    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}
impl EvalOp for Scan {
    // Scan carries per-run state (hidden state, position), so not stateless.
    fn is_stateless(&self) -> bool {
        false
    }
    /// Delegate state creation to the codegen (LIR) form of the op.
    fn state(
        &self,
        session: &mut SessionState,
        node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        self.to_codegen_op(false)?.state(session, node_id)
    }
}
impl TypedOp for Scan {
as_op!();
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut outputs = tvec!();
let iters = {
let info = self.input_mapping.iter().flat_map(|it| it.as_scan()).next().unwrap();
inputs[info.slot].shape[info.axis].clone().div_ceil(info.chunk.unsigned_abs() as u64)
};
for (ix, output) in self.output_mapping.iter().enumerate() {
let fact = self.body.output_fact(ix)?;
if let Some(info) = output.scan {
let mut shape = fact.shape.clone();
let scanning_dim =
output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
shape.set(info.axis, scanning_dim);
outputs.push((info.slot, fact.datum_type.fact(shape)));
}
if let Some(slot) = output.last_value_slot {
outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
}
}
outputs.sort_by_key(|a| a.0);
anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
Ok(outputs)
}More examples
174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301
/// Execute the scan loop: slice each scanned input along its axis, run the
/// body once per chunk, thread hidden state between iterations, and
/// accumulate scanned / last-value outputs.
fn eval(
    &mut self,
    session: &mut SessionState,
    _op: &dyn Op,
    inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
    let State { op, ref mut hidden_state, ref mut position, ref mut model_state } = self;
    // initialize state at first pass
    if hidden_state.len() == 0 {
        for input in &op.input_mapping {
            if let InputMapping::State { initializer } = input {
                hidden_state.push(match initializer {
                    StateInitializer::FromInput(slot) => inputs[*slot].clone(),
                    StateInitializer::Value(v) => (**v).to_owned().into_tvalue(),
                });
            }
        }
    }
    // Iteration count: scanned axis length over chunk size, rounded up.
    let iters = {
        let info = op
            .input_mapping
            .iter()
            .find_map(|it| match it {
                InputMapping::Scan(info) => Some(info),
                _ => None,
            })
            .unwrap();
        inputs[info.slot].shape()[info.axis].divceil(info.chunk.unsigned_abs())
    };
    // Preallocate the outer outputs, keyed then ordered by outer slot.
    let mut outputs = tvec!();
    for (ix, output) in op.output_mapping.iter().enumerate() {
        if let Some(info) = output.scan {
            let fact = op.plan.model().output_fact(ix)?;
            let mut shape: TVec<usize> =
                fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned();
            let scanning_dim = output
                .full_dim_hint
                .as_ref()
                .and_then(|d| d.to_usize().ok())
                .unwrap_or(shape[info.axis] * iters);
            shape[info.axis] = scanning_dim;
            // NOTE(review): uninitialized buffer — assumed fully overwritten
            // by assign_output below before being returned; confirm the skip
            // logic cannot leave slices unwritten.
            let t = unsafe { Tensor::uninitialized_dt(fact.datum_type, &shape)? };
            outputs.push((info.slot, t));
        }
        if let Some(slot) = output.last_value_slot {
            outputs.push((slot, Tensor::default()));
        }
    }
    outputs.sort_by_key(|a| a.0);
    let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();
    for i in 0..iters {
        *position += 1;
        if *position <= op.skip {
            continue;
        }
        // Reverse then pop yields the states back in mapping order.
        hidden_state.reverse();
        let iter_inputs: TVec<TValue> = op
            .input_mapping
            .iter()
            .map(|m| {
                Ok(match m {
                    InputMapping::State { .. } => Some(hidden_state.pop().unwrap()),
                    InputMapping::Scan(info) => Some(
                        Self::slice_input(&inputs[info.slot], info.axis, i, info.chunk)?
                            .into_tvalue(),
                    ),
                    InputMapping::Full { slot } => Some(inputs[*slot].clone()),
                })
            })
            .collect::<TractResult<Vec<_>>>()?
            .into_iter()
            .flatten()
            .collect();
        trace!("iter_inputs #{}: {:?}", i, iter_inputs);
        let iter_outputs =
            model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
        trace!("iter_outputs #{}: {:?}", i, iter_outputs);
        // Route each body output to its scan slice, last-value slot and/or
        // next-iteration hidden state.
        for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
            if let Some(info) = mapping.scan {
                Self::assign_output(&mut outputs[info.slot], info.axis, &v, i, info.chunk < 0);
            }
            if i == iters - 1 {
                if let Some(slot) = mapping.last_value_slot {
                    outputs[slot] = v.clone().into_tensor();
                }
            }
            if mapping.state {
                hidden_state.push(v);
            }
        }
    }
    Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect())
}
}
impl TypedOp for LirScan {
as_op!();
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut outputs = tvec!();
let iters = {
let info = self.input_mapping.iter().find_map(|it| it.as_scan()).unwrap();
inputs[info.slot].shape[info.axis].clone().div_ceil(info.chunk.unsigned_abs() as _)
};
for (ix, output) in self.output_mapping.iter().enumerate() {
let fact = self.plan.model().output_fact(ix)?;
if let Some(slot) = output.last_value_slot {
outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
}
if let Some(info) = output.scan {
let mut shape = fact.shape.clone();
let scanning_dim =
output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
shape.set(info.axis, scanning_dim);
outputs.push((info.slot, fact.datum_type.fact(shape)));
}
}
outputs.sort_by_key(|a| a.0);
let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
Ok(outputs)
}sourcepub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F>
pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F>
Get the ix-th output tensor type information, mutably.
pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()>
pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()>
Set the ix-th output tensor type information.
pub fn with_output_fact(self, output: usize, fact: F) -> TractResult<Self>
pub fn with_output_fact(self, output: usize, fact: F) -> TractResult<Self>
Set the ix-th output tensor type information and return self.
pub fn node_names(&self) -> impl Iterator<Item = &str>
pub fn node_names(&self) -> impl Iterator<Item = &str>
Iterate over all node names.
pub fn node_id_by_name(&self, name: &str) -> TractResult<usize>
pub fn node_id_by_name(&self, name: &str) -> TractResult<usize>
pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>>
pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>>
Find a node by its name.
Examples found in repository?
More examples
162 163 164 165 166 167 168 169 170 171 172 173 174 175
pub fn set_input_names(
&mut self,
inputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<()> {
let mut ids = vec![];
for i in inputs.into_iter() {
let node = self.node_by_name(&i)?;
for o in 0..node.outputs.len() {
ids.push(OutletId::new(node.id, o))
}
}
self.inputs = ids;
Ok(())
}sourcepub fn node_by_name_mut(
&mut self,
name: impl AsRef<str>
) -> TractResult<&mut Node<F, O>>
pub fn node_by_name_mut(
&mut self,
name: impl AsRef<str>
) -> TractResult<&mut Node<F, O>>
Borrow mutably a node by its name.
pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()>
pub fn node(&self, id: usize) -> &Node<F, O>
pub fn node(&self, id: usize) -> &Node<F, O>
Find a node by its id.
Examples found in repository?
508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530
/// Walk `count` steps backwards through single-precursor links.
///
/// Returns the node reached, or `None` as soon as some node on the way
/// does not have exactly one precursor.
pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
    let mut current = self.node(id);
    let mut remaining = count;
    while remaining > 0 {
        match self.single_prec(current.id)? {
            Some(prec) => current = prec,
            None => return Ok(None),
        }
        remaining -= 1;
    }
    Ok(Some(current))
}
pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
let mut node = self.node(id);
for _ in 0..count {
if let Some(next) = self.single_succ(node.id)? {
node = next
} else {
return Ok(None);
}
}
Ok(Some(node))
}More examples
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496
/// Try axis changes suggested by the body's own nodes and substitute the
/// scan op when one of them applies cleanly.
fn declutter_body_axes(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    // Collect all suggestions first, then try them one by one.
    let mut suggestions = vec![];
    for n in self.body.eval_order()? {
        let node = self.body.node(n);
        for suggestion in node.op.suggested_axis_changes()? {
            let outlet = suggestion.0.as_outlet(node);
            suggestions.push(AxisChange { outlet, op: suggestion.1 })
        }
    }
    for suggestion in suggestions.into_iter() {
        // locked_interface = true: the outer scan interface must survive.
        if let Some(op) =
            self.try_body_axes_change(suggestion, true)?.and_then(|c| c.substitute_op)
        {
            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &node.inputs,
                op,
            )?));
        }
    }
    Ok(None)
}
/// Shift outer-input slot indices after removing outer input `discarded`.
///
/// Every slot reference strictly greater than `discarded` is decremented
/// by one; value-based state initializers are carried over unchanged.
fn remove_outer_input_from_mappings(
    mappings: &[InputMapping],
    discarded: usize,
) -> Vec<InputMapping> {
    mappings
        .iter()
        .map(|m| match m {
            &InputMapping::Full { slot } => {
                // (slot > discarded) as usize is 1 exactly when the slot moves down.
                InputMapping::Full { slot: slot - (slot > discarded) as usize }
            }
            &InputMapping::Scan(info) => InputMapping::Scan(ScanInfo {
                slot: info.slot - (info.slot > discarded) as usize,
                ..info
            }),
            InputMapping::State { initializer } => {
                let initializer = match initializer {
                    StateInitializer::FromInput(n) => {
                        StateInitializer::FromInput(*n - (*n > discarded) as usize)
                    }
                    StateInitializer::Value(v) => StateInitializer::Value(v.clone()),
                };
                InputMapping::State { initializer }
            }
        })
        .collect()
}
/// Shift outer-output slot indices after removing outer output `discarded`:
/// scan slots and last-value slots above it move down by one.
fn remove_outer_output_from_mappings(
    mappings: &[OutputMapping<TDim>],
    discarded: usize,
) -> Vec<OutputMapping<TDim>> {
    mappings
        .iter()
        .map(|m| OutputMapping {
            scan: m.scan.map(|info| ScanInfo {
                slot: info.slot - (info.slot > discarded) as usize,
                ..info
            }),
            last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
            full_dim_hint: m.full_dim_hint.clone(),
            state: m.state,
        })
        .collect()
}
/// Inline a constant outer input used as a state initializer.
///
/// When the outer input behind a StateInitializer::FromInput is a known
/// constant, bake its value into the mapping and drop the outer input.
fn declutter_const_initializer(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let inputs = model.node_input_facts(node.id)?;
    for (ix, mapping) in self.input_mapping.iter().enumerate() {
        if let InputMapping::State { initializer: StateInitializer::FromInput(n) } = mapping {
            if let Some(i) = inputs[*n].konst.as_ref() {
                let mut op = self.clone();
                op.input_mapping[ix] =
                    InputMapping::State { initializer: StateInitializer::Value(i.clone()) };
                // All slot indices above *n shift down once the input is removed.
                op.input_mapping =
                    Self::remove_outer_input_from_mappings(&op.input_mapping, *n);
                let mut inputs = node.inputs.clone();
                inputs.remove(*n);
                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
            }
        }
    }
    Ok(None)
}
    // Drop an input mapping whose body-side source node is never consumed
    // (no successors, and not itself a body output): the corresponding outer
    // input slot (if any) and the body source node can both be removed.
    fn declutter_discard_unused_input_mapping(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (inner_input_id, input) in self.body.input_outlets()?.iter().enumerate() {
            let source_node = self.body.node(input.node);
            if source_node.outputs[0].successors.len() == 0
                && !self.body.output_outlets()?.contains(input)
            {
                let mut new_inputs = node.inputs.clone();
                // Outer slot feeding this body input, if any. A state
                // initialized by value has no outer slot.
                let slot = match &self.input_mapping[inner_input_id] {
                    InputMapping::Full { slot } => Some(*slot),
                    InputMapping::Scan(info) => Some(info.slot),
                    InputMapping::State { initializer } => match initializer {
                        StateInitializer::FromInput(n) => Some(*n),
                        _ => None,
                    },
                };
                let mut new_mappings: Vec<_> = self.input_mapping.clone();
                new_mappings.remove(inner_input_id);
                if let Some(slot) = slot {
                    // Renumber outer slots past the removed one.
                    new_mappings = Self::remove_outer_input_from_mappings(&new_mappings, slot);
                }
                let mut model_inputs = self.body.input_outlets()?.to_vec();
                if let Some(slot) = slot {
                    new_inputs.remove(slot);
                }
                model_inputs.remove(inner_input_id);
                let mut body = self.body.clone();
                let mut patch = TypedModelPatch::default();
                // Replace the dead source with a dummy, then clean up the body.
                patch.obliterate(source_node.id)?;
                patch.apply(&mut body)?;
                body.set_input_outlets(&model_inputs)?;
                body.declutter()?;
                let op = Self {
                    body,
                    skip: self.skip,
                    seq_length_input_slot: self.seq_length_input_slot,
                    input_mapping: new_mappings,
                    decluttered: true,
                    output_mapping: self.output_mapping.clone(),
                };
                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
            }
        }
        Ok(None)
    }
    // If an outer output of the scan node is neither consumed downstream nor
    // a model output, stop exporting it and renumber the remaining slots.
    fn declutter_discard_useless_outer_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (ix, o) in node.outputs.iter().enumerate() {
            if o.successors.len() == 0
                && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
            {
                // Clear any mapping that exported through slot ix...
                let mappings = self
                    .output_mapping
                    .iter()
                    .map(|m| OutputMapping {
                        scan: m.scan.filter(|info| info.slot != ix),
                        last_value_slot: m.last_value_slot.filter(|s| *s != ix),
                        full_dim_hint: m.full_dim_hint.clone(),
                        state: m.state,
                    })
                    .collect::<Vec<_>>();
                let mut op = self.clone();
                // ...then shift slots above ix down by one.
                op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
                let mut patch = TypedModelPatch::default();
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| patch.tap_model(model, i))
                    .collect::<TractResult<Vec<_>>>()?;
                let wires = patch.wire_node(&*node.name, op, &inputs)?;
                // Re-route all surviving outputs to the new node's shifted slots.
                for oix in 0..node.outputs.len() {
                    if oix != ix {
                        patch.shunt_outside(
                            model,
                            OutletId::new(node.id, oix),
                            wires[oix - (oix > ix) as usize],
                        )?;
                    }
                }
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
fn declutter_discard_empty_output_mapping_with_body_output(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (ix, om) in self.output_mapping.iter().enumerate() {
if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
let mut new_op = self.clone();
new_op.output_mapping.remove(ix);
new_op.body.outputs.remove(ix);
new_op.decluttered = false;
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&node.inputs,
new_op,
)?));
}
}
Ok(None)
}
    // Hoist a single-input/single-output op applied to a scanned input out of
    // the loop body: apply the op once to the full outer tensor, feed its
    // result as a new scanned input, and shunt the op's body-side output to a
    // fresh source. This is valid when the op's invariants track the scan
    // axis through the op (unary_track_axis_down).
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (model_input, input) in self.input_mapping.iter().enumerate() {
            if let Some(info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[model_input];
                let scan_source_node = self.body.node(scan_source.node);
                for successor in &scan_source_node.outputs[0].successors {
                    let successor_node = self.body.node(successor.node);
                    if successor_node.inputs.len() != 1 || successor_node.outputs.len() != 1 {
                        continue;
                    }
                    let (input_facts, output_facts) = self.body.node_facts(successor_node.id)?;
                    let invariants = successor_node.op.invariants(&input_facts, &output_facts)?;
                    if let Some(axis_after) = invariants.unary_track_axis_down(info.axis, false) {
                        // Outer patch: run the op on the full input, outside the loop.
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            successor_node
                        ));
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        let input = patch_inputs[info.slot];
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            successor_node.op.clone(),
                            &[input],
                        )?[0];
                        patch_inputs.push(new_input_wire);
                        // Inner fact is the outer fact restricted to one chunk
                        // along the (tracked) scan axis.
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, info.chunk.abs().to_dim());
                        let mut new_body = self.body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            new_input_inner_fact,
                        )?;
                        // Inner patch: replace the op's output by the new source.
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            successor_node
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(successor.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;
                        // The hoisted result is scanned like the original input,
                        // but along the tracked axis and via a new outer slot.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: info.chunk,
                            slot: node.inputs.len(),
                        }));
                        let new_op = Self {
                            input_mapping,
                            output_mapping: self.output_mapping.clone(),
                            decluttered: false,
                            body: new_body,
                            skip: self.skip,
                            seq_length_input_slot: self.seq_length_input_slot,
                        };
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }
    // If a body output exported as a last-value turns out to be a constant,
    // replace the corresponding outer output by a plain constant node.
    fn declutter_pull_constant_outputs(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some(slot) = mapping.last_value_slot {
                if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
                    let inner_node = self.body.output_outlets()?[model_output_ix].node;
                    let inner_node = self.body.node(inner_node);
                    let mut patch = TypedModelPatch::new(format!("Extract const node {}", inner_node));
                    let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
                    patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
                    return Ok(Some(patch));
                }
            }
        }
        Ok(None)
    }
    // Mirror of declutter_pull_batcheable_input for outputs: move the op that
    // produces a scanned output out of the body. The op's inputs are exported
    // from the scan instead (scanned, or as last-value for constants) and the
    // op is applied once, outside the loop, on the accumulated tensors.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (model_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some(info) = mapping.scan {
                let emitter_outlet = self.body.output_outlets()?[model_ix];
                let emitter_node = self.body.node(emitter_outlet.node);
                if emitter_node.outputs[emitter_outlet.slot].successors.len() > 0
                    || mapping.state
                    || mapping.scan.map(|i| i.chunk > 1).unwrap_or(true)
                {
                    // continue if both last_value and full values are exported
                    continue;
                }
                let (input_facts, output_facts) = self.body.node_facts(emitter_node.id)?;
                let invariants = emitter_node.op.invariants(&input_facts, &output_facts)?;
                // The scan axis must be trackable backwards through the op.
                let Some(axis_before) = invariants.unary_track_axis_up(info.axis, false)
                else {
                    continue;
                };
                let mut new_body = self.body.clone();
                let mut new_output_mapping = self.output_mapping.clone();
                let mut new_scan_outputs = node.outputs.len();
                // Outer slots where each of the emitter's inputs gets exported.
                let mut outer_slots = vec![];
                for input in &emitter_node.inputs {
                    if new_body.outputs.iter().all(|o| o != input) {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                    }
                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                    let mut mapping = &mut new_output_mapping[body_output_id];
                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                        // Constants only need their last value exported.
                        if mapping.last_value_slot.is_none() {
                            mapping.last_value_slot = Some(new_scan_outputs);
                        }
                        new_scan_outputs += 1;
                        mapping.last_value_slot.unwrap()
                    } else {
                        // Non-constants are scanned along the tracked axis.
                        if mapping.scan.is_none() {
                            mapping.scan = Some(ScanInfo {
                                slot: new_scan_outputs,
                                axis: axis_before,
                                chunk: info.chunk,
                            });
                            new_scan_outputs += 1;
                        }
                        mapping.scan.unwrap().slot
                    };
                    outer_slots.push(outer_slot);
                }
                let mut outside_patch = TypedModelPatch::new(format!(
                    "Outside patch for output extraction of {}",
                    emitter_node
                ));
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| outside_patch.tap_model(model, i))
                    .collect::<TractResult<TVec<_>>>()?;
                let new_op = Self {
                    input_mapping: self.input_mapping.clone(),
                    output_mapping: new_output_mapping,
                    decluttered: false,
                    body: new_body,
                    skip: self.skip,
                    seq_length_input_slot: self.seq_length_input_slot,
                };
                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
                // NB: `mapping` here is the outer loop binding again: the
                // original scan output mapping (inner `mapping` went out of scope).
                let output = mapping.scan.unwrap();
                let inputs =
                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
                // Re-apply the emitter op outside the loop on the exported tensors.
                let wire = outside_patch.wire_node(
                    &*emitter_node.name,
                    emitter_node.op.clone(),
                    &inputs,
                )?[0];
                outside_patch.shunt_outside(model, OutletId::new(node.id, output.slot), wire)?;
                for output_slot in 0..node.outputs.len() {
                    if output_slot != output.slot {
                        outside_patch.shunt_outside(
                            model,
                            OutletId::new(node.id, output_slot),
                            OutletId::new(scan_outputs[0].node, output_slot),
                        )?;
                    }
                }
                return Ok(Some(outside_patch));
            }
        }
        Ok(None)
    }
    // Optimizer pass: try the axis changes suggested by each op, at most once
    // per (node, suggestion) pair (tracked in the `self.0` set), and return
    // the first change that yields a viable patch.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Model outputs and inputs form the interface change_axes must preserve.
        let mut interfaces = model.output_outlets()?.to_vec();
        interfaces.extend(model.input_outlets()?.iter());
        for n in model.eval_order()? {
            for suggestion in model.node(n).op.suggested_axis_changes()? {
                // insert() returns false when this suggestion was already tried.
                if self.0.insert((n, suggestion.clone())) {
                    let outlet = suggestion.0.as_outlet(model.node(n));
                    let change = AxisChange { outlet, op: suggestion.1.clone() };
                    if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[])
                        .with_context(|| {
                            format!("Making patch for {:?} from {}", change, model.node(n))
                        })?
                    {
                        return Ok(Some(patch));
                    }
                }
            }
        }
        Ok(None)
    }
    // Optimizer pass: merge duplicated nodes. Two successors of the same
    // outlet that are `same_as` each other compute the same values, so the
    // second one is shunted to the first and obliterated.
    fn next(&mut self, _session: &mut OptimizerSession, model: &TypedModel) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for node in model.eval_order()? {
            for output in &model.node(node).outputs {
                for (a, b) in output.successors.iter().tuple_combinations() {
                    // b may already have been merged earlier in this pass.
                    if patch.obliterate.contains(&b.node) {
                        continue;
                    }
                    let a = model.node(a.node);
                    let b = model.node(b.node);
                    if a.same_as(b) {
                        // Redirect every output slot of b to the matching slot of a.
                        for slot in 0..b.outputs.len() {
                            let tap = patch.tap_model(model, OutletId::new(a.id, slot))?;
                            patch.shunt_outside(model, OutletId::new(b.id, slot), tap)?;
                            patch.obliterate(b.id)?;
                        }
                    }
                }
            }
        }
        // Only report a patch if something was actually merged.
        Ok(Some(patch).filter(|p| !p.is_empty()))
    }
    // Optimizer pass: eager constant folding. Any stateless non-Const node
    // whose inputs are all constants is evaluated now and replaced by Const
    // nodes, one per output.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for n in model.eval_order()? {
            let node = model.node(n);
            if node.op.is_stateless() && !node.op_is::<Const>() {
                // collect() over Options: Some only if every input is constant.
                if let Some(inputs) = model
                    .node_input_facts(n)?
                    .iter()
                    .map(|f| f.konst.clone().map(|t| t.into_tvalue()))
                    .collect()
                {
                    match node.op.eval(inputs) {
                        Ok(res) => {
                            for (ix, output) in res.into_iter().enumerate() {
                                // First output keeps the node name, others get a suffix.
                                let mut name = node.name.clone();
                                if ix > 0 {
                                    name = format!("{}.{}", name, ix);
                                }
                                let wire = patch.add_const(name, output.into_arc_tensor())?;
                                patch.shunt_outside(model, (n, ix).into(), wire)?;
                            }
                        }
                        Err(e) => {
                            // Evaluation failing on an undetermined symbol is
                            // expected (value depends on a runtime dim): skip.
                            if !e.root_cause().is::<UndeterminedSymbol>() {
                                Err(e).with_context(|| {
                                    format!("Eager eval {} during optimisation", model.node(n))
                                })?;
                            }
                        }
                    }
                }
            }
        }
        Ok(Some(patch).filter(|p| p.nodes.len() > 0))
    }
pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O>
Find a node by its id.
Examples found in repository:
More examples:
    /// Apply this patch to `target`: import the patch's nodes (deduplicating
    /// against identical existing nodes), rewire shunted outlets, obliterate
    /// replaced nodes and restore the input/output interface unchanged.
    pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
        let prior_target_inputs = target.input_outlets()?.len();
        let prior_target_outputs = target.output_outlets()?.len();
        let ModelPatch {
            model: patch,
            incoming: mut mapping,
            shunt_outlet_by,
            obliterate,
            inputs: replaced_inputs,
            ..
        } = self;
        let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
        let mut model_input_outlets = target.input_outlets()?.to_vec();
        for node in patch.nodes {
            if <Graph<F, O>>::is_source(&node.op)
                && mapping.contains_key(&OutletId::new(node.id, 0))
            {
                // this is a tap
                continue;
            }
            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
            let n_outputs = outputs.len();
            // Deduplication: if an identical node (same op, same resolved
            // inputs) already exists in target, map to it instead of adding.
            for dup in 0..target.nodes.len() {
                if target.node(dup).op().same_as(op.as_ref())
                    && inputs.len() == target.node(dup).inputs.len()
                    && inputs
                        .iter()
                        .zip(target.node(dup).inputs.iter())
                        .all(|(patch_input, d)| mapping[patch_input] == *d)
                {
                    for ix in 0..n_outputs {
                        mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
                    }
                    continue;
                }
            }
            let facts = outputs.into_iter().map(|of| of.fact).collect();
            let added_node_id = target.add_node(name, op, facts)?;
            for ix in 0..n_outputs {
                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
            }
            all_inputs.insert(added_node_id, inputs);
            if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
                // this is actually an input replacement
                model_input_outlets.iter_mut().for_each(|oo| {
                    if oo.node == replaced_inputs[&patch_node_id] {
                        oo.node = added_node_id;
                    }
                });
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Rewire shunted outlets: successors, model outputs and labels follow.
        for (outlet, by) in shunt_outlet_by {
            let replace_by = mapping[&by];
            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
            for succ in succs {
                target.add_edge(replace_by, succ)?;
            }
            for o in target.outputs.iter_mut() {
                if *o == outlet {
                    *o = replace_by;
                }
            }
            if let Some(label) = target.outlet_labels.remove(&outlet) {
                target.set_outlet_label(replace_by, label)?;
            }
        }
        if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
            bail!("Duplicate usage of node as output");
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Connect the freshly added nodes to their (now resolved) inputs.
        for (node, inputs) in all_inputs {
            for (ix, input) in inputs.into_iter().enumerate() {
                target.add_edge(mapping[&input], InletId::new(node, ix))?;
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Replaced nodes become dummies rather than being removed, so ids stay stable.
        for node in obliterate {
            target.node_mut(node).op = target.create_dummy();
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        target.set_input_outlets(&model_input_outlets)?;
        Ok(())
    }
/// Swap a Downsample placed after a Scan node to before it: downsampling the
/// scan inputs and shrinking the chunks is equivalent to downsampling the
/// outputs, when the downsample ops can swim from the body outputs all the
/// way up to the body inputs during declutter.
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    // Negative strides (reversed downsampling) are not handled here.
    if down_op.stride < 0 {
        return Ok(None);
    }
    // introduce downsample at end of body
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;
    // check if downsample ops introduced at end have swimmed up to scan inputs during declutter
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            // Some path still consumes the raw input: the swap is not valid.
            return Ok(None);
        }
    }
    // Absorb the downsample into the sources: shrink the source facts, then
    // shunt away the now-redundant downsample nodes right after each source.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = &mut downsampled_body.node_mut(input.node);
        let fact = &mut node.outputs[0].fact;
        *fact = down_op.transform_fact(fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>().unwrap().fact = fact.clone();
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }
    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;
    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    // Adjust the scan interface: downsample value initializers, shrink chunks.
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // Chunk must divide evenly by the stride (positive chunks only).
                if info.chunk > 0 && info.chunk as usize % down_op.stride as usize != 0 {
                    return Ok(None);
                }
                info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                    * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            if info.chunk as usize % down_op.stride as usize != 0 {
                return Ok(None);
            }
            info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                * info.chunk.signum()
        }
    }
    // Outer patch: downsample every scan input, then wire the new scan and
    // shunt the downstream downsample nodes to the scan outputs.
    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    for ix in 0..scan_node.outputs.len() {
        // FIXME need to check earlier on that all output are followed by a ds
        let succ = scan_node.outputs[ix].successors[0].node;
        patch.shunt_outside(model, OutletId::new(succ, 0), scan[ix])?;
    }
    Ok(Some(patch))
}
pub fn nodes(&self) -> &[Node<F, O>]
Access the nodes table.
Examples found in repository:
More examples:
    // Run one declutter/optimise sub-pass: apply the rule closure (self.1) to
    // each node in eval order and return the first patch produced.
    fn full_pass(
        &mut self,
        session: &mut OptimizerSession,
        new: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        // self.2 is a cursor: resume where the previous call stopped instead
        // of rescanning the whole model after every patch.
        for (ix, &id) in new.eval_order()?.iter().enumerate().skip(self.2) {
            let node = &new.nodes()[id];
            let patch = (self.1)(node.op.as_ref(), session, new, node)
                .with_context(|| format!("{:?} node {}", self, node))?;
            if let Some(mut p) = patch {
                p.push_context(format!("{:?} {}", self, node));
                // Skip past this node next time if the patch must not be applied twice.
                self.2 = ix + p.dont_apply_twice.is_some() as usize;
                return Ok(Some(p));
            }
        }
        Ok(None)
    }
    /// Set model outputs by name. A name may be an outlet label, an explicit
    /// "node_name:slot" form, or a bare node name (resolving to its slot 0).
    pub fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        // Candidate names: outlet labels first, then "name:slot" for every output.
        let mut labels: HashMap<Cow<str>, OutletId> =
            self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
        for n in self.nodes() {
            for ix in 0..n.outputs.len() {
                labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
            }
        }
        let ids: Vec<OutletId> = outputs
            .into_iter()
            .map(|s| {
                let s = s.as_ref();
                labels
                    .get(s)
                    .cloned()
                    // Fall back to a bare node name, taking its first output.
                    .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
                    .ok_or_else(|| format_err!("Node {} not found", s))
            })
            .collect::<TractResult<_>>()?;
        self.outputs = ids;
        Ok(())
    }
    /// Set model outputs by node names and return `self`.
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs)?;
        Ok(self)
    }
    /// Get the `ix`-th output tensor type information.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        let output = self.output_outlets()?[ix];
        self.outlet_fact(output)
    }
    /// Get the `ix`-th output tensor type information, mutably.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        let output = self.output_outlets()?[ix];
        self.outlet_fact_mut(output)
    }
    /// Set the `ix`-th output tensor type information.
    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
        let outlet = self.outputs[output];
        self.set_outlet_fact(outlet, fact)
    }
    /// Set the `ix`-th output tensor type information and return `self`.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact)?;
        Ok(self)
    }
    // nodes and their facts
    /// Iterate over all node names.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.iter().map(|s| &*s.name)
    }
    /// Find a node id by its name (linear scan over the node table).
    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        self.nodes
            .iter()
            .find(|n| n.name == name)
            .map(|n| n.id)
            .with_context(|| format!("No node found for name: \"{}\"", name))
    }
    /// Find a node by its name.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        let id: usize = self.node_id_by_name(name.as_ref())?;
        Ok(&self.nodes[id])
    }
    /// Borrow mutably a node by its name.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let id: usize = self.node_id_by_name(name.as_ref())?;
        Ok(&mut self.nodes[id])
    }
    /// Rename the node with the given id.
    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        self.node_mut(id).name = name.to_string();
        Ok(())
    }
    /// Find a node by its id.
    pub fn node(&self, id: usize) -> &Node<F, O> {
        &self.nodes[id]
    }
    /// Find a node by its id.
    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
        &mut self.nodes[id]
    }
    /// Access the nodes table.
    pub fn nodes(&self) -> &[Node<F, O>] {
        &self.nodes
    }
    /// Access the nodes table.
    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
        &mut self.nodes
    }
    /// Get input and output tensor information for a node.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
    }
    /// Get input tensor information for a node.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
    }
    /// Get output tensor information for a node.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
    }
    // outlets
    /// Get tensor information for a single outlet.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let outlets = &self.nodes[outlet.node].outputs;
        outlets
            .get(outlet.slot)
            .map(|o| &o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }
    /// Get tensor information for a single outlet.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }
    /// Get multiple mutable tensor information for outlets.
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        // SAFETY: the assert above guarantees all requested outlets are
        // pairwise distinct, so the mutable references never alias.
        unsafe {
            outlets
                .iter()
                .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
                .collect()
        }
    }
    /// Set tensor information for a single outlet.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        if outlets.len() <= outlet.slot {
            bail!("Invalid outlet refererence: {:?}", outlet)
        }
        outlets[outlet.slot].fact = fact;
        Ok(())
    }
    /// Set tensor information for a single outlet and return `self`.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact)?;
        Ok(self)
    }
    // outlet labels
    /// Get label for an outlet.
    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
        self.outlet_labels.get(&outlet).map(|s| &**s)
    }
    /// Set label for an outlet.
    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
        self.outlet_labels.insert(outlet, label);
        Ok(())
    }
    /// Set label for an outlet and return `self`.
    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
        self.set_outlet_label(outlet, label)?;
        Ok(self)
    }
    /// Find outlet by label (linear scan over the label map).
    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
        self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
    }
    // misc
    /// Computes an evalutation order for the graph inputs and outputs
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        eval_order(self)
    }
    // No-op in release / non-paranoid builds: edge checking is debug-only.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }
    /// Performs a sanity check on network connections.
    // Verifies that every input edge is reciprocated by a successor entry on
    // the producing outlet, and vice versa.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        for node_id in self.eval_order()? {
            let node = &self.nodes[node_id];
            for (ix, input) in node.inputs.iter().enumerate() {
                let prec = &self.nodes[input.node];
                if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
                    bail!(
                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        prec
                    )
                }
            }
            for (ix, output) in node.outputs.iter().enumerate() {
                for succ in &output.successors {
                    if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
                        bail!(
                            "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                            node.id,
                            ix,
                            succ
                        )
                    }
                }
            }
        }
        Ok(())
    }
    /// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        crate::plan::SimplePlan::new(self)
    }
    /// Returns the unique predecessor of `id`, if the node has exactly one
    /// input and that input's producer has exactly one successor overall.
    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        if node.inputs.len() != 1 {
            return Ok(None);
        }
        let prec = &self.nodes()[node.inputs[0].node];
        if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        Ok(Some(prec))
    }
    /// Walks `count` single-predecessor hops up from `id`.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut node = self.node(id);
        for _ in 0..count {
            if let Some(next) = self.single_prec(node.id)? {
                node = next
            } else {
                return Ok(None);
            }
        }
        Ok(Some(node))
    }
    /// Walks `count` single-successor hops down from `id`.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut node = self.node(id);
        for _ in 0..count {
            if let Some(next) = self.single_succ(node.id)? {
                node = next
            } else {
                return Ok(None);
            }
        }
        Ok(Some(node))
    }
    /// Returns the unique successor of `id`, if the node has exactly one
    /// successor overall and that successor has exactly one input.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        let succ = node.outputs[0].successors[0];
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }
pub fn new_for_outputs_and_deps(
model: M,
outputs: &[OutletId],
deps: &[(usize, usize)],
) -> TractResult<SimplePlan<F, O, M>> {
let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
let order = eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?;
let mut values_needed_until_step = vec![0; model.borrow().nodes().len()];
for (step, node) in order.iter().enumerate() {
for i in &model.borrow().node(*node).inputs {
values_needed_until_step[i.node] = step;
}
}
for o in outputs.iter() {
values_needed_until_step[o.node] = order.len();
}
let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
if flush_at != 0 {
flush_lists[flush_at].push(node)
}
}
let mut symbols: std::collections::HashSet<Symbol> = Default::default();
for node in &model.borrow().nodes {
for output in &node.outputs {
if let Ok(fact) = output.fact.to_typed_fact() {
symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
}
}
}
Ok(SimplePlan {
model,
order,
flush_lists,
outputs: outputs.to_vec(),
has_unresolved_symbols: !symbols.is_empty(),
_casper: PhantomData,
})
}
/// One-shot evaluation: build a fresh execution state for this plan and
/// run it over `inputs`.
pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
    SimpleState::new(self)?.run(inputs)
}
/// Borrow the model this plan was built from.
pub fn model(&self) -> &Graph<F, O> {
    self.model.borrow()
}
}
/// Execution state for a SimplePlan: per-node op states, session-wide
/// state and the tensor values computed so far.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O, M, P>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
    M: Borrow<Graph<F, O>> + Hash,
    P: Borrow<SimplePlan<F, O, M>>,
{
    plan: P,
    /// One optional stateful evaluator per node (None for stateless ops).
    pub states: Vec<Option<Box<dyn OpState>>>,
    /// Session-wide mutable state (inputs, resolved symbols...) shared by all ops.
    pub session_state: SessionState,
    /// Computed values indexed by node id; None if not yet computed or already flushed.
    pub values: Vec<Option<TVec<TValue>>>,
    _phantom: PhantomData<(M, F, O)>,
}
impl<F, O, M, P> SimpleState<F, O, M, P>
where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
M: Borrow<Graph<F, O>> + Hash,
P: Borrow<SimplePlan<F, O, M>> + Clone,
{
/// Initialize a fresh execution state for `plan`, building the per-node
/// op states up front.
pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
    let mut session_state = SessionState::default();
    let model = plan.borrow().model();
    let node_count = model.nodes().len();
    let mut states: Vec<Option<Box<dyn OpState>>> = Vec::with_capacity(node_count);
    for node in model.nodes() {
        states.push(node.op().state(&mut session_state, node.id)?);
    }
    Ok(SimpleState {
        plan,
        states,
        session_state,
        values: vec![None; node_count],
        _phantom: PhantomData,
    })
}
/// Reset wires state: drop every computed value, keeping op states intact.
pub fn reset_turn(&mut self) -> TractResult<()> {
    for slot in self.values.iter_mut() {
        *slot = None;
    }
    Ok(())
}
/// Reset op inner state.
pub fn reset_op_states(&mut self) -> TractResult<()> {
let &mut SimpleState { ref plan, ref mut session_state, ref mut states, .. } = self;
*states = plan
.borrow()
.model()
.nodes()
.iter()
.map(|n| n.op().state(session_state, n.id))
.collect::<TractResult<_>>()?;
Ok(())
}
/// Run the plan with the default eval function: set inputs, execute all
/// steps, collect outputs and reset the computed values.
pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
    self.run_plan_with_eval(inputs, self::eval)
}
/// Execute the plan with the default eval function, assuming inputs were
/// already set. Computed values are left in place.
pub fn exec(&mut self) -> TractResult<()> {
    self.exec_plan_with_eval(self::eval)
}
/// Run the plan with a caller-provided eval function: set the inputs,
/// execute every step, collect the outputs, then reset computed values.
///
/// `eval` is called once per node with the session state, the node's
/// optional OpState, the node itself and its input values.
pub fn run_plan_with_eval<Eval, E>(
    &mut self,
    inputs: TVec<TValue>,
    eval: Eval,
) -> TractResult<TVec<TValue>>
where
    Eval: for<'a, 'b, 'c> FnMut(
        &'a mut SessionState,
        Option<&'b mut (dyn OpState + 'static)>,
        &'c Node<F, O>,
        TVec<TValue>,
    ) -> Result<TVec<TValue>, E>,
    E: Into<anyhow::Error> + Send + Sync + 'static,
{
    self.set_inputs(inputs)?;
    self.exec_plan_with_eval(eval)?;
    let outputs = self.outputs()?;
    self.reset_turn()?;
    Ok(outputs)
}
/// Execute the whole plan with a caller-provided eval function, assuming
/// the inputs were already set in the session state.
///
/// Values are flushed as soon as their last consumer has read them
/// (driven by the plan's flush lists). In debug builds, inputs and
/// outputs are additionally checked against the model facts at each step.
pub fn exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
where
    Eval: for<'a, 'b, 'c> FnMut(
        &'a mut SessionState,
        Option<&'b mut (dyn OpState + 'static)>,
        &'c Node<F, O>,
        TVec<TValue>,
    ) -> Result<TVec<TValue>, E>,
    E: Into<anyhow::Error> + Send + Sync + 'static,
{
    {
        // Destructure so plan, session_state, states and values can be
        // borrowed simultaneously.
        let &mut SimpleState {
            ref plan,
            ref mut session_state,
            ref mut states,
            ref mut values,
            ..
        } = self;
        let plan = plan.borrow();
        let model = plan.model().borrow();
        for (step, n) in plan.order.iter().enumerate() {
            let node = model.node(*n);
            trace!("Running step {}, node {}", step, node);
            // Gather (clone) this node's inputs from precursor outputs.
            let mut inputs: TVec<TValue> = tvec![];
            for i in &node.inputs {
                trace!(" use input {:?}", i);
                let prec_node = model.node(i.node);
                let prec = values[i.node].as_ref().ok_or_else(|| {
                    // Message aligned with prepare_inputs (was ending with a stray colon).
                    format_err!("Computing {}, precursor {} not done.", node, prec_node)
                })?;
                inputs.push(prec[i.slot].clone())
            }
            // Inputs are cloned above, so values whose last consumer is this
            // step can be dropped before evaluation to reduce peak memory.
            for flush in &plan.flush_lists[step] {
                trace!(" Ran {} can now flush {}", node, model.node(*flush));
                values[*flush] = None;
            }
            // Debug builds: verify the inputs against the model input facts.
            if cfg!(debug_assertions) {
                let facts = model.node_input_facts(node.id)?;
                if facts.len() != inputs.len() {
                    bail!(
                        "Evaluating {}: expected {} inputs, got {}",
                        node,
                        facts.len(),
                        inputs.len()
                    );
                }
                for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
                    if !f.matches(v, Some(&session_state.resolved_symbols))? {
                        bail!(
                            "Evaluating {}: input {:?}, expected {:?}, got {:?}",
                            node,
                            ix,
                            f,
                            v
                        );
                    }
                }
            }
            let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)
                .map_err(|e| e.into())?;
            // Match symbolic output dims against the concrete shapes just
            // produced, so later fact checks can use the resolved values.
            if plan.has_unresolved_symbols {
                for (o, v) in node.outputs.iter().zip(vs.iter()) {
                    if let Ok(f) = o.fact.to_typed_fact() {
                        for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
                            Self::resolve(
                                &mut session_state.resolved_symbols,
                                &dim_abstract,
                                *dim_concrete as i64,
                            );
                        }
                    }
                }
            }
            // Debug builds: verify outputs against the model output facts,
            // skipping outputs that nobody consumes.
            if cfg!(debug_assertions) {
                let facts = model.node_output_facts(node.id)?;
                if facts.len() != vs.len() {
                    bail!(
                        "Evaluating {}: expected {} outputs, got {}",
                        node,
                        facts.len(),
                        vs.len()
                    );
                }
                for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
                    if node.outputs[ix].successors.len() == 0 {
                        continue;
                    }
                    if !f.matches(v, Some(&session_state.resolved_symbols))? {
                        bail!(
                            "Evaluating {}: output {:?}, expected {:?}, got {:?}",
                            node,
                            ix,
                            f,
                            v
                        );
                    }
                }
            }
            values[node.id] = Some(vs);
        }
    }
    Ok(())
}
/// Assign all model inputs at once, in input-outlet order.
pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
    let expected = self.model().inputs.len();
    ensure!(
        inputs.len() == expected,
        "Wrong number of inputs for model. Expected {} got {}",
        expected,
        inputs.len()
    );
    inputs.into_iter().enumerate().try_for_each(|(ix, t)| self.set_input(ix, t))
}
/// Record a symbol value deduced by matching a symbolic dimension against
/// a concrete one, recursing through integer multiples (e.g. 2*S matched
/// against 10 resolves S to 5). Other dim shapes are ignored.
fn resolve(symbols: &mut SymbolValues, expected: &TDim, provided: i64) {
    if let TDim::Sym(s) = expected {
        symbols[s] = Some(provided);
    } else if let TDim::MulInt(x, expr) = expected {
        Self::resolve(symbols, expr, provided / *x);
    }
}
/// Set the value for one model input, resolving any symbolic dimensions
/// of the input fact against the concrete shape of `t`, then checking
/// the value against the (possibly resolved) fact.
pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
    let outlet: OutletId = *self
        .model()
        .input_outlets()?
        .get(input)
        .ok_or_else(|| format_err!("Invalid input id for model ({}).", input))?;
    // Destructure so plan and session_state can be borrowed simultaneously.
    let SimpleState { plan, session_state, .. } = self;
    let plan = (*plan).borrow();
    let model = plan.model.borrow();
    if let Ok(fact) = model.outlet_fact(outlet)?.to_typed_fact() {
        for (expected, provided) in fact.shape.iter().zip(t.shape()) {
            Self::resolve(&mut session_state.resolved_symbols, &expected, *provided as i64)
        }
    }
    let fact = self.plan.borrow().model().outlet_fact(outlet)?;
    ensure!(
        fact.matches(&t, Some(&self.session_state.resolved_symbols))
            .with_context(|| format!("Setting input {}", input))?,
        "Input at index {} has incorrect dtype or shape (got shape {:?} and dtype {:?}, expected to match fact {:?})",
        input,
        t.shape(),
        t.datum_type(),
        fact
    );
    self.session_state.inputs.insert(outlet.node, t);
    Ok(())
}
/// Access the computed value of the `id`-th model output.
///
/// Fails if the plan has not been executed yet or the value was flushed.
pub fn output(&self, id: usize) -> TractResult<&TValue> {
    let outlet = self.model().output_outlets()?.get(id).with_context(|| {
        format!(
            "Required output {}, only have {}",
            id,
            self.model().output_outlets().unwrap().len()
        )
    })?;
    let value: &TValue = self
        .values
        .get(outlet.node)
        .context("node id for output beyond node values array")?
        .as_ref()
        .context("node is not an output")?
        .get(outlet.slot)
        .context("slot id too high")?;
    Ok(value)
}
/// Collect (clone) the values of all model outputs, in output order.
pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
    let SimpleState { ref plan, ref mut values, .. } = self;
    let mut v = tvec![];
    for o in plan.borrow().outputs.iter() {
        let vs = values[o.node].as_mut().ok_or_else(|| {
            format_err!(
                "Outputs of {:?} are not computed",
                &plan.borrow().model().nodes()[o.node]
            )
        })?;
        v.push(vs[o.slot].clone())
    }
    Ok(v)
}
/// Override the computed output values for node `id`.
pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
    self.values[id] = Some(values);
    Ok(())
}
/// Override the single computed output value for node `id`.
pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
    self.set_values(id, tvec!(value))
}
/// Gather (clone) the input values for `node` from the computed outputs
/// of its precursors. Fails if any precursor has not been computed.
pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
    let SimpleState { ref plan, ref values, .. } = self;
    let plan = plan.borrow();
    let nodes = plan.model().nodes();
    let node = &nodes[node];
    let mut inputs: TVec<TValue> = tvec![];
    for i in &node.inputs {
        let prec_node = &nodes[i.node];
        let prec = values[i.node].as_ref().ok_or_else(|| {
            format_err!("Computing {}, precursor {} not done.", node, prec_node)
        })?;
        inputs.push(prec[i.slot].clone())
    }
    Ok(inputs)
}
/// Evaluate a single node, pulling its inputs from already-computed values.
pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
    let inputs = self.prepare_inputs(node)?;
    self.compute_one_with_inputs(node, inputs)
}
/// Evaluate a single node with explicitly provided inputs, storing its
/// outputs into the state's `values`.
pub fn compute_one_with_inputs(
    &mut self,
    node: usize,
    inputs: TVec<TValue>,
) -> TractResult<()> {
    let SimpleState { ref plan, ref mut session_state, ref mut values, .. } = self;
    let plan = plan.borrow();
    let nodes = plan.model().nodes();
    let node = &nodes[node];
    // Stateful ops evaluate through their OpState, stateless ones directly.
    let vs = match self.states[node.id] {
        Some(ref mut state) => state.eval(session_state, node.op(), inputs),
        None => node.op().eval(inputs),
    }
    .with_context(|| format!("Evaluating {}", node))?;
    values[node.id] = Some(vs);
    Ok(())
}
pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
let values = {
#[allow(clippy::needless_collect)] // clippy bug ?
let precs: Vec<usize> =
self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
for i in precs.into_iter() {
if self.values[i].is_none() {
let _ = self.compute_recursively(i)?;
}
}
let mut inputs: TVec<TValue> = tvec![];
{
let node = &self.model().nodes()[node];
for i in &node.inputs {
inputs.push(self.values[i.node].as_ref().unwrap()[i.slot].clone())
}
}
let Self { ref mut states, ref mut session_state, ref plan, .. } = self;
let plan = plan.borrow();
match states[node] {
Some(ref mut state) => {
state.eval(session_state, plan.borrow().model().nodes()[node].op(), inputs)
}
None => plan.borrow().model().nodes()[node].op().eval(inputs),
}
.with_context(|| format!("Evaluating {:?}", node))?
};
self.values[node] = Some(values);
Ok(self.values[node].as_ref().unwrap())
}229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
fn declutter_precusor_is_concat(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(concat) = model.nodes()[node.inputs[0].node].op().downcast_ref::<TypedConcat>()
{
let mut patch = TypedModelPatch::new("split over k-concatenated input");
if concat.axis == self.axes.b_k {
let concat_node = model.node(node.inputs[0].node);
let offsets = concat
.offsets(&model.node_input_facts(concat_node.id)?)?
.iter()
.map(|x| x.to_usize())
.collect::<TractResult<Vec<usize>>>()?;
let mut wires = vec![];
for (ix, input) in concat_node.inputs.iter().enumerate() {
let wire = patch.tap_model(model, *input)?;
let a = self.a.slice(self.axes.a_k, offsets[ix], offsets[ix + 1])?;
let wire = patch.wire_node(
format!("{}.k-{}-{}", node.name, offsets[ix], offsets[ix + 1]),
MatMulUnary { a: a.into_arc_tensor(), ..self.clone() },
&[wire],
)?[0];
wires.push(wire)
}
let mut wire = wires[0];
for (ix, w) in wires[1..].iter().enumerate() {
wire = patch.wire_node(
format!("{}.k-add-{}", node.name, ix),
crate::ops::binary::TypedBinOp(Box::new(crate::ops::math::Add)),
&[wire, *w],
)?[0];
}
patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
return Ok(Some(patch));
}
}
Ok(None)
}157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
use crate::ops::array::TypedConcat;
if let Some(concat) = model.nodes()[node.inputs[0].node].op().downcast_ref::<TypedConcat>()
{
let mut patch = TypedModelPatch::new("split over k-concatenated input");
let k_axis = self.axes.a_k;
if concat.axis == self.axes.b_k {
let concat_node = model.node(node.inputs[0].node);
let offsets = concat
.offsets(&model.node_input_facts(concat_node.id)?)?
.iter()
.map(|x| x.to_usize())
.collect::<TractResult<Vec<usize>>>()?;
let mut wires = vec![];
let mut params_for_split = self.params.clone();
params_for_split.a_scale = tensor0(1.0f32).into();
params_for_split.b_scale = tensor0(1.0f32).into();
params_for_split.c_scale = tensor0(1.0f32).into();
params_for_split.c0 = tensor0(0i32).into();
let input_outlets = node
.inputs
.iter()
.skip(1)
.map(|o| patch.tap_model(model, *o))
.collect::<TractResult<TVec<_>>>()?;
let params_outlets = self.params.as_outlet_ids(
&mut patch,
&node.name,
&input_outlets,
self.a.datum_type(),
model.node_input_facts(node.id)?[0].datum_type,
self.output_type,
)?;
let scale = combine_scales(
&mut patch,
&node.name,
params_outlets[1],
params_outlets[3],
params_outlets[5],
)?;
let c0 = params_outlets[4];
for (ix, input) in concat_node.inputs.iter().enumerate() {
let wire = patch.tap_model(model, *input)?;
let a = self.a.slice(k_axis, offsets[ix], offsets[ix + 1])?;
let wire = patch
.wire_node(
format!("{}.k-{}-{}", node.name, offsets[ix], offsets[ix + 1]),
Self {
a: a.into_arc_tensor(),
output_type: DatumType::I32,
bias: self.bias.clone().filter(|_| ix == 0),
params: params_for_split.clone(),
..self.clone()
},
&[wire],
)
.context("wiring new matmulunary")?[0];
wires.push(wire)
}
let mut wire = wires[0];
for (ix, w) in wires[1..].iter().enumerate() {
wire = patch.wire_node(
format!("{}.k-add-{}", node.name, ix),
crate::ops::binary::TypedBinOp(Box::new(crate::ops::math::Add)),
&[wire, *w],
)?[0];
}
wire = requant(&mut patch, &node.name, wire, self.output_type, scale, c0)?;
patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
return Ok(Some(patch));
}
}
Ok(None)
}sourcepub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)>
pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)>
Get input and output tensor information for a node.
Examples found in repository:
fn pull_downsample_up(
model: &TypedModel,
down_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
model.check_consistency()?;
let down_op = down_node.op_as::<Downsample>().unwrap();
if let Some(prec) = model.single_prec(down_node.id)? {
let (input_facts, output_facts) = model.node_facts(prec.id)?;
let invariants = prec.op.invariants(&input_facts, &output_facts)?;
debug!("Consider pull {:?} over {:?} (invariants: {:?})", down_op, prec, invariants);
if let Some(slice_op) = prec.op_as::<ops::array::Slice>() {
if let Some(p) = array::pull_downsample_over_slice(model, prec, slice_op, down_node, down_op)? {
return Ok(Some(p))
}
} else if let Some(other_op) = prec.op_as::<AxisOp>() {
return array::pull_downsample_over_axis_op(model, prec, other_op, down_node, down_op);
} else if let Some(conv_op) = prec.op_as::<ops::cnn::conv::ConvUnary>() {
return conv::fuse_downsample_into_conv(model, prec, conv_op, down_node, down_op);
} else if let Some(other_op) = prec.op_as::<ops::scan::Scan>() {
return scan::pull_downsample_over_scan(model, prec, other_op, down_node, down_op);
}
if let Some(above_axis) = invariants.unary_track_axis_up(down_op.axis, false) {
let mut patch = TypedModelPatch::default();
let mut inputs = vec![];
for (ix, &oo) in prec.inputs.iter().enumerate() {
let source = patch.tap_model(model, oo)?;
let mut op = down_op.clone();
op.axis = above_axis;
let ds = patch.wire_node(
format!("{}.{}-{}", down_node.name, prec.name, ix),
op,
[source].as_ref(),
)?;
inputs.push(ds[0]);
}
let other = patch.wire_node(&prec.name, prec.op.clone(), &inputs)?;
patch.shunt_outside(model, OutletId::new(down_node.id, 0), other[0])?;
return Ok(Some(patch));
}
}
Ok(None)
}
More examples:
fn next(
&mut self,
_session: &mut OptimizerSession,
model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
for n in model.eval_order()? {
let (ifacts, ofacts) = model.node_facts(n)?;
if ofacts.len() != 1 {
continue;
}
let node = model.node(n);
let invariants = node.op.invariants(&ifacts, &ofacts)?;
'axis: for axis in 0..ofacts[0].rank() {
if let Some(boundaries) = should_slice_output(model, node, axis)? {
let mut splits = tvec!();
let mut patch = TypedModelPatch::new("push slice up");
let inputs = node
.inputs
.iter()
.map(|i| patch.tap_model(model, *i))
.collect::<TractResult<TVec<OutletId>>>()?;
let mut start = 0;
let axis_info = invariants.track_output_axis(0, axis);
for end in &boundaries {
let mut wires = tvec!();
for input_ix in 0..inputs.len() {
let mut wire = inputs[input_ix];
if let Some(input_axis) = axis_info.and_then(|it| it.inputs[input_ix]) {
wire = patch.wire_node(
format!(
"{}.split-{}-over-{}.{}..{}.slice",
&node.name, input_ix, input_axis, start, end
),
Slice {
axis: input_axis,
start: start.to_dim(),
end: end.to_dim(),
},
&[wire],
)?[0];
}
wires.push(wire);
}
let Some(wire) = node.op.slice(
&mut patch,
&format!(
"{}.split-over-{}.{}..{}",
&node.name, axis, start, end
),
&wires,
axis,
start,
*end,
)? else {
continue 'axis };
splits.push(wire[0]);
start = *end;
}
rewire_sliced_outputs(model, node, axis, &mut patch, &boundaries, &splits)?;
return Ok(Some(patch));
}
}
}
Ok(None)
}
pub fn for_outlet_and_axis(
model: &TypedModel,
outlet: OutletId,
axis: usize,
) -> TractResult<AxisTracking> {
let mut mapped_outlets = OutletMap::default();
let mut todo = OutletMap::default();
let mut disposable = true;
let mut creators = tvec!();
let mut destructors = tvec!();
mapped_outlets.insert(outlet, axis);
todo.insert(outlet, ());
while let Some(wire) = todo.keys().next() {
todo.remove(&wire);
let axis = mapped_outlets[&wire];
let emiter_node = model.node(wire.node);
let mut nodes = vec![];
let (input_facts, output_facts) = model.node_facts(emiter_node.id)?;
let invs = emiter_node
.op
.invariants(&input_facts, &output_facts)
.with_context(|| format!("Computing invariants for {}", emiter_node))?;
assert!(invs.axes.iter().all(|axis| axis.inputs.len() == emiter_node.inputs.len()));
assert!(invs.axes.iter().all(|axis| axis.outputs.len() == emiter_node.outputs.len()));
if let Some(info) = invs.track_output_axis(wire.slot, axis) {
nodes.push((wire.node, info.clone()));
} else {
creators.push(wire);
};
for succ in &emiter_node.outputs[wire.slot].successors {
let succ_node = model.node(succ.node);
let (input_facts, output_facts) = model.node_facts(succ_node.id)?;
let invs = succ_node.op.invariants(&input_facts, &output_facts)?;
assert!(invs.axes.iter().all(|axis| axis.inputs.len() == succ_node.inputs.len()));
assert!(invs.axes.iter().all(|axis| axis.outputs.len() == succ_node.outputs.len()));
if let Some(info) = invs.track_input_axis(succ.slot, axis) {
nodes.push((succ_node.id, info.clone()));
} else {
destructors.push(*succ);
};
}
let mut new_outlets = vec![];
for (n, axes) in nodes {
disposable = disposable && axes.disposable;
let node = model.node(n);
for slot in 0..node.outputs.len() {
if let Some(axis) = axes.outputs[slot] {
new_outlets.push((OutletId::new(n, slot), axis));
}
}
for slot in 0..node.inputs.len() {
if let Some(axis) = axes.inputs[slot] {
new_outlets.push((node.inputs[slot], axis));
}
}
}
for (outlet, axis) in new_outlets {
if let Some(prev) = mapped_outlets.get(&outlet) {
if *prev != axis {
bail!("Inconsistent network");
}
} else {
mapped_outlets.insert(outlet, axis);
todo.insert(outlet, ());
}
}
}
Ok(AxisTracking { creators, destructors, outlets: mapped_outlets, disposable })
}
fn declutter_pull_batcheable_input(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (model_input, input) in self.input_mapping.iter().enumerate() {
if let Some(info) = input.as_scan() {
let scan_source = self.body.input_outlets()?[model_input];
let scan_source_node = self.body.node(scan_source.node);
for successor in &scan_source_node.outputs[0].successors {
let successor_node = self.body.node(successor.node);
if successor_node.inputs.len() != 1 || successor_node.outputs.len() != 1 {
continue;
}
let (input_facts, output_facts) = self.body.node_facts(successor_node.id)?;
let invariants = successor_node.op.invariants(&input_facts, &output_facts)?;
if let Some(axis_after) = invariants.unary_track_axis_down(info.axis, false) {
let mut outside_patch = TypedModelPatch::new(format!(
"Outer patch for input extraction of {}",
successor_node
));
let mut patch_inputs = node
.inputs
.iter()
.map(|&i| outside_patch.tap_model(model, i))
.collect::<TractResult<TVec<_>>>()?;
let input = patch_inputs[info.slot];
let new_input_wire = outside_patch.wire_node(
format!("{}.extracted.{}", node.name, successor_node.name),
successor_node.op.clone(),
&[input],
)?[0];
patch_inputs.push(new_input_wire);
let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
let mut new_input_inner_fact = new_input_outer_fact.clone();
new_input_inner_fact.shape.set(axis_after, info.chunk.abs().to_dim());
let mut new_body = self.body.clone();
let new_source_wire = new_body.add_source(
format!("{}.extracted.{}", node.name, successor_node.name),
new_input_inner_fact,
)?;
let mut inner_patch = TypedModelPatch::new(format!(
"Inner body patch for extraction of {}",
successor_node
));
let new_source_wire_in_patch =
inner_patch.tap_model(&new_body, new_source_wire)?;
inner_patch
.shunt_outside(
&new_body,
OutletId::new(successor.node, 0),
new_source_wire_in_patch,
)
.with_context(|| "patching inner model")?;
inner_patch.apply(&mut new_body)?;
let mut input_mapping = self.input_mapping.clone();
input_mapping.push(InputMapping::Scan(ScanInfo {
axis: axis_after,
chunk: info.chunk,
slot: node.inputs.len(),
}));
let new_op = Self {
input_mapping,
output_mapping: self.output_mapping.clone(),
decluttered: false,
body: new_body,
skip: self.skip,
seq_length_input_slot: self.seq_length_input_slot,
};
let output_wires =
outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
for w in output_wires {
outside_patch
.shunt_outside(model, OutletId::new(node.id, w.slot), w)
.with_context(|| "patching outer model")?;
}
return Ok(Some(outside_patch));
}
}
}
}
Ok(None)
}
fn declutter_pull_constant_outputs(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
if let Some(slot) = mapping.last_value_slot {
if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
let inner_node = self.body.output_outlets()?[model_output_ix].node;
let inner_node = self.body.node(inner_node);
let mut patch = TypedModelPatch::new(format!("Extract const node {}", inner_node));
let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
return Ok(Some(patch));
}
}
}
Ok(None)
}
fn declutter_pull_batcheable_output(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (model_ix, mapping) in self.output_mapping.iter().enumerate() {
if let Some(info) = mapping.scan {
let emitter_outlet = self.body.output_outlets()?[model_ix];
let emitter_node = self.body.node(emitter_outlet.node);
if emitter_node.outputs[emitter_outlet.slot].successors.len() > 0
|| mapping.state
|| mapping.scan.map(|i| i.chunk > 1).unwrap_or(true)
{
// continue if both last_value and full values are exported
continue;
}
let (input_facts, output_facts) = self.body.node_facts(emitter_node.id)?;
let invariants = emitter_node.op.invariants(&input_facts, &output_facts)?;
let Some(axis_before) = invariants.unary_track_axis_up(info.axis, false)
else {
continue;
};
let mut new_body = self.body.clone();
let mut new_output_mapping = self.output_mapping.clone();
let mut new_scan_outputs = node.outputs.len();
let mut outer_slots = vec![];
for input in &emitter_node.inputs {
if new_body.outputs.iter().all(|o| o != input) {
new_output_mapping.push(OutputMapping::default());
new_body.outputs.push(*input);
}
let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
let mut mapping = &mut new_output_mapping[body_output_id];
let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
if mapping.last_value_slot.is_none() {
mapping.last_value_slot = Some(new_scan_outputs);
}
new_scan_outputs += 1;
mapping.last_value_slot.unwrap()
} else {
if mapping.scan.is_none() {
mapping.scan = Some(ScanInfo {
slot: new_scan_outputs,
axis: axis_before,
chunk: info.chunk,
});
new_scan_outputs += 1;
}
mapping.scan.unwrap().slot
};
outer_slots.push(outer_slot);
}
let mut outside_patch = TypedModelPatch::new(format!(
"Outside patch for output extraction of {}",
emitter_node
));
let inputs = node
.inputs
.iter()
.map(|&i| outside_patch.tap_model(model, i))
.collect::<TractResult<TVec<_>>>()?;
let new_op = Self {
input_mapping: self.input_mapping.clone(),
output_mapping: new_output_mapping,
decluttered: false,
body: new_body,
skip: self.skip,
seq_length_input_slot: self.seq_length_input_slot,
};
let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
let output = mapping.scan.unwrap();
let inputs =
outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
let wire = outside_patch.wire_node(
&*emitter_node.name,
emitter_node.op.clone(),
&inputs,
)?[0];
outside_patch.shunt_outside(model, OutletId::new(node.id, output.slot), wire)?;
for output_slot in 0..node.outputs.len() {
if output_slot != output.slot {
outside_patch.shunt_outside(
model,
OutletId::new(node.id, output_slot),
OutletId::new(scan_outputs[0].node, output_slot),
)?;
}
}
return Ok(Some(outside_patch));
}
}
Ok(None)
}
fn declutter(
&self,
model: &TypedModel,
dequant: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let mut current = dequant;
let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
while let Some(quant) = model.single_succ(current.id)? {
let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
} else {
op.0.downcast_ref::<QuantizeLinearI8>()
.map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
}
} else {
None
};
if let Some((scale, zero_point, dt)) = q_params {
// first, try Op::quantize() on all ops in the chain
let mut patch = TypedModelPatch::default();
let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
let mut next = model.single_succ(dequant.id)?.unwrap();
loop {
if let Some(op) = next
.op
.quantize(model, dequant, dt, scale, zero_point)
.with_context(|| format!("Quantizing {}", next))?
{
wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
} else {
break;
}
if next.id == current.id {
patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
return Ok(Some(patch));
} else {
next = model.single_succ(next.id)?.unwrap();
}
}
// or else make a lookup table
if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
let mut adhoc_model = TypedModel::default();
let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
let mut next = model.single_succ(dequant.id)?.unwrap();
let mut name = None;
// plug in dequant
wire = adhoc_model.wire_node(
&*dequant.name,
dequant.op.clone(),
[wire].as_ref(),
)?[0];
while next.id != quant.id {
name.get_or_insert(&*next.name);
wire =
adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
[0];
next = model.single_succ(next.id)?.unwrap();
}
// plug in quant
wire =
adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
adhoc_model.set_output_outlets(&[wire])?;
let input = (0u8..=255).collect::<Vec<u8>>();
let input = match dt {
DatumType::I8 => unsafe {
tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
},
DatumType::U8 => tensor1(&input),
_ => unreachable!(),
};
let output =
SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
let table: &[u8] = match dt {
DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
DatumType::U8 => output.as_slice::<u8>()?,
_ => unreachable!(),
};
let op = lookup_table((tract_linalg::ops().lut_u8)(table));
let mut patch = TypedModelPatch::default();
let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
return Ok(Some(patch));
}
}
let (input_facts, output_facts) = model.node_facts(quant.id)?;
let invariants = quant
.op
.invariants(&input_facts, &output_facts)
.with_context(|| format!("Querying invariants for {}", quant))?;
if invariants.element_wise() {
current = quant;
} else {
break;
}
}
Ok(None)
}
pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>>
Get input tensor information for a node.
Examples found in repository:
pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
}
/// Get input tensor information for a node.
pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
}
/// Get output tensor information for a node.
pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
}
// outlets
/// Get tensor information for a single outlet.
///
/// Errors if the outlet's node id or slot does not exist in the graph.
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
    anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
    let outlets = &self.nodes[outlet.node].outputs;
    outlets
        .get(outlet.slot)
        .map(|o| &o.fact)
        .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get tensor information for a single outlet.
pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
let outlets = &mut self.nodes[outlet.node].outputs;
outlets
.get_mut(outlet.slot)
.map(|o| &mut o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get multiple mutable tensor information for outlets.
///
/// Panics if the same outlet is requested twice: two `&mut` references to the
/// same fact would be aliasing and undefined behavior.
pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
// Reject duplicates up front (pairwise comparison of all requested outlets).
assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
// SAFETY: each outlet is distinct (asserted above), so every `&mut F`
// produced by the const->mut pointer cast refers to a different fact and
// no two returned references alias.
unsafe {
outlets
.iter()
.map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
.collect()
}
}
/// Set tensor information for a single outlet.
pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
    let outlets = &mut self.nodes[outlet.node].outputs;
    if outlets.len() <= outlet.slot {
        // Fixed typo in the error message ("refererence" -> "reference").
        bail!("Invalid outlet reference: {:?}", outlet)
    }
    outlets[outlet.slot].fact = fact;
    Ok(())
}
/// Set tensor information for a single outlet and return `self`.
pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
    self.set_outlet_fact(outlet, fact)?;
    Ok(self)
}
// outlet labels
/// Get label for an outlet.
pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
    self.outlet_labels.get(&outlet).map(String::as_str)
}
/// Set label for an outlet.
pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
    self.outlet_labels.insert(outlet, label);
    Ok(())
}
/// Set label for an outlet and return `self`.
pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
    self.set_outlet_label(outlet, label)?;
    Ok(self)
}
/// Find outlet by label.
pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
    self.outlet_labels
        .iter()
        .find_map(|(outlet, l)| if l.as_str() == label { Some(*outlet) } else { None })
}
// misc
/// Computes an evaluation order for the graph inputs and outputs
pub fn eval_order(&self) -> TractResult<Vec<usize>> {
eval_order(self)
}
// No-op placeholder: the expensive edge-consistency check is only compiled
// in when both debug_assertions and the "paranoid_assertions" feature are on
// (see the other `check_edges` definition below).
#[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
#[inline]
pub fn check_edges(&self) -> TractResult<()> {
Ok(())
}
/// Performs a sanity check on network connections.
///
/// Walks every node in evaluation order and verifies that the edge lists are
/// reciprocal: each input of a node appears among the successors of the
/// producing outlet, and each successor recorded on an outlet points back at
/// this node/slot.
#[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
#[inline]
pub fn check_edges(&self) -> TractResult<()> {
for node_id in self.eval_order()? {
let node = &self.nodes[node_id];
// Forward direction: every input edge must be mirrored by the producer.
for (ix, input) in node.inputs.iter().enumerate() {
let prec = &self.nodes[input.node];
if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
bail!(
"Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
node.id,
ix,
prec
)
}
}
// Backward direction: every recorded successor must point back here.
for (ix, output) in node.outputs.iter().enumerate() {
for succ in &output.successors {
if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
bail!(
"Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
node.id,
ix,
succ
)
}
}
}
}
Ok(())
}
/// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
    crate::plan::SimplePlan::new(self)
}
/// If node `id` has exactly one input, and that predecessor's outputs feed
/// exactly one inlet overall, return the predecessor node.
pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
    let node = &self.nodes()[id];
    if node.inputs.len() != 1 {
        return Ok(None);
    }
    let prec = &self.nodes()[node.inputs[0].node];
    let fanout: usize = prec.outputs.iter().map(|of| of.successors.len()).sum();
    Ok(if fanout == 1 { Some(prec) } else { None })
}
/// Walk `count` steps backwards through single-predecessor links; `None` as
/// soon as a step is not a unique predecessor.
pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
    let mut current = self.node(id);
    for _ in 0..count {
        match self.single_prec(current.id)? {
            Some(prev) => current = prev,
            None => return Ok(None),
        }
    }
    Ok(Some(current))
}
/// Walk `count` steps forwards through single-successor links; `None` as
/// soon as a step is not a unique successor.
pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
    let mut current = self.node(id);
    for _ in 0..count {
        match self.single_succ(current.id)? {
            Some(next) => current = next,
            None => return Ok(None),
        }
    }
    Ok(Some(current))
}
/// If node `id` feeds exactly one inlet overall and that successor has a
/// single input, return the successor node.
pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
    let node = &self.nodes()[id];
    if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
        return Ok(None);
    }
    // Fix: the unique successor is not necessarily on output slot 0. The old
    // `node.outputs[0].successors[0]` panicked when a multi-output node's only
    // successor hung off a later slot; scan all outputs instead.
    let succ = node
        .outputs
        .iter()
        .flat_map(|of| of.successors.iter())
        .next()
        .expect("total successor count is 1, checked above");
    let succ = &self.nodes()[succ.node];
    if succ.inputs.len() != 1 {
        return Ok(None);
    }
    Ok(Some(succ))
}
/// All the inlets fed by the given outlet.
pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
    let node = &self.nodes[outlet.node];
    &node.outputs[outlet.slot].successors
}
}
impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Add a constant node wrapping tensor `v` and return its single outlet.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let tensor = v.into_arc_tensor();
        // The fact is derived from the tensor itself (F: From<Arc<Tensor>>).
        let fact = F::from(tensor.clone());
        let op = crate::ops::konst::Const::new(tensor);
        self.add_node(name.into(), op, tvec!(fact)).map(|id| id.into())
    }
}
impl<F, O> fmt::Display for Graph<F, O>
where
F: Fact + Hash + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
/// Render the whole graph as a table: one row per node showing up to two
/// inputs, up to two successors of output slot 0, the op name, the node name
/// and the input/output facts. Extra inputs/outputs and outlet labels get
/// additional indented lines.
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
for i in 0..self.nodes.len() {
// First two input outlets, rendered as debug strings ("" when absent).
let input_1 = self.nodes[i]
.inputs
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let input_2 = self.nodes[i]
.inputs
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
// First two successors of output slot 0 only; other slots are handled
// by the "* output #" lines below.
let output_1 = self
.outlet_successors(OutletId::new(i, 0))
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_2 = self
.outlet_successors(OutletId::new(i, 0))
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
writeln!(
fmt,
"{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
i,
input_1,
input_2,
output_1,
output_2,
self.nodes[i].op().name(),
self.nodes[i].name,
// NOTE(review): these unwraps panic on an invalid graph; presumably
// acceptable for a Display used on well-formed models.
self.node_input_facts(i).unwrap(),
self.node_output_facts(i).unwrap(),
)?;
// Inputs beyond the two shown in the main row.
if self.nodes[i].inputs.len() > 2 {
writeln!(
fmt,
" | * inputs: {}",
self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
)?;
}
// Detail lines when the node has several outputs, a wide fanout, or a
// label that differs from the node name.
if self.nodes[i].outputs.len() > 1
|| self.outlet_successors((i, 0).into()).len() > 2
|| (self.outlet_label(i.into()).is_some()
&& self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
{
for o in 0..self.nodes[i].outputs.len() {
if self.outlet_successors((i, o).into()).len() > 0 {
writeln!(
fmt,
" | * output #{}: {} {}",
o,
self.outlet_label((i, o).into()).unwrap_or(""),
self.outlet_successors((i, o).into())
.iter()
.map(|s| format!("{:?}", s))
.join(", "),
)?;
}
}
}
}
writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
Ok(())
}
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
/// Lower to a finite (fully shaped) matmul when the B input shape is concrete.
fn codegen(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let b = args_1!(model.node_input_facts(node.id)?);
    match b.shape.as_concrete() {
        Some(b_shape) => {
            let patch = self.new_mat_mul_unary_finite(model, node, b_shape, b.datum_type)?;
            Ok(Some(patch))
        }
        None => Ok(None),
    }
}
as_op!();
}
impl MatMulUnary {
/// Build a patch replacing this MatMulUnary node by a pre-packed
/// `LirMatMulUnary` pipeline, once B's shape (`b_shape`) is fully known:
/// pick a kernel, pre-pack the constant A operand, insert a pack node for B,
/// then the low-level matmul node.
fn new_mat_mul_unary_finite(
&self,
model: &TypedModel,
node: &TypedNode,
b_shape: &[usize],
b_dt: DatumType,
) -> TractResult<TypedModelPatch> {
let mut patch = TypedModelPatch::default();
let mut wire = patch.tap_model(model, node.inputs[0])?;
let c_dt = output_type(self.a.datum_type());
let (m, k, n, c_shape) = compute_shape(self.a.shape(), b_shape, self.axes)?;
// Select a matrix-multiply kernel for these datum types and dimensions.
let mmm = tract_linalg::ops()
.mmm(self.a.datum_type(), b_dt, c_dt, Some(m), Some(k), Some(n))
.with_context(|| {
format!(
"No matrix multiplier for {:?}x{:?} to {:?}",
self.a.datum_type(),
b_dt,
c_dt
)
})?;
// Iterate over A's "batch" prefix: m and k axes are collapsed to 1 so each
// iteration addresses one m*k slab of A.
let mut a_iter_shape: TVec<usize> = self.a.shape().into();
a_iter_shape[self.axes.a_m] = 1;
a_iter_shape[self.axes.a_k] = 1;
let packed_as = Array::from_shape_fn(&*a_iter_shape, |a_prefix| unsafe {
// Byte offset of this slab inside A, computed from strides.
let offset = a_prefix
.as_array_view()
.iter()
.zip(self.a.strides())
.map(|(x, s)| *x as isize * s)
.sum::<isize>()
* self.a.datum_type().size_of() as isize;
// SAFETY: buffer is fully written by `a_pack().pack` just below,
// with the kernel-required length and alignment.
let mut pa = Tensor::uninitialized_aligned_dt(
self.a.datum_type(),
&[mmm.a_pack().len(k, m)],
mmm.a_pack().alignment(),
)
.unwrap();
mmm.a_pack().pack(
&mut pa.view_mut(),
TensorView::from_bytes(&self.a, offset, self.a.shape(), self.a.strides()),
self.axes.a_k,
self.axes.a_m,
);
(pa.into_arc_tensor(), vec![ProtoFusedSpec::Store])
});
unsafe {
// Packed B layout: drop the k and n axes (larger index first so the
// smaller one stays valid), append the flat packed buffer length.
let mut packed_b_shape: TVec<usize> = b_shape.into();
packed_b_shape.remove(self.axes.b_k.max(self.axes.b_n));
packed_b_shape.remove(self.axes.b_k.min(self.axes.b_n));
packed_b_shape.push(mmm.b_pack().len(k, n));
wire = patch.wire_node(
format!("{}.pack", &*node.name),
super::MatMatMulPack {
packer: mmm.b_pack(),
k_axis: self.axes.b_k,
mn_axis: self.axes.b_n,
},
&[wire],
)?[0];
let b_storage = mmm.b_packed(b_dt.size_of(), k);
let geometry = ConcreteMatMulGeometry { m, k, n, b_storage };
wire = patch.wire_node(
format!("{}.matmatmul", &*node.name),
LirMatMulUnary {
c_fact: c_dt.fact(&c_shape),
geometry: MatMulGeometry::Concrete(geometry),
micro_ops: packed_as,
c_m_axis: self.axes.c_m,
c_n_axis: self.axes.c_n,
c_final_shape: c_shape.into(),
reshape_post: vec![],
mmm,
},
&[wire],
)?[0];
// Reroute consumers of the old node to the new pipeline and drop it.
patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
patch.obliterate(node.id)?;
}
Ok(patch)
}
/// If the B input comes from a concat along the k axis, split the matmul:
/// one MatMulUnary per concat chunk (with A sliced to the matching k range),
/// then sum the partial products. (Note: "precusor" in the name is a historic
/// typo, kept for interface stability.)
fn declutter_precusor_is_concat(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(concat) = model.nodes()[node.inputs[0].node].op().downcast_ref::<TypedConcat>()
{
let mut patch = TypedModelPatch::new("split over k-concatenated input");
// Only applies when the concatenation axis is the contraction axis.
if concat.axis == self.axes.b_k {
let concat_node = model.node(node.inputs[0].node);
// Concrete k offsets of each concatenated chunk boundary.
let offsets = concat
.offsets(&model.node_input_facts(concat_node.id)?)?
.iter()
.map(|x| x.to_usize())
.collect::<TractResult<Vec<usize>>>()?;
let mut wires = vec![];
for (ix, input) in concat_node.inputs.iter().enumerate() {
let wire = patch.tap_model(model, *input)?;
// Slice A over the same k range as this chunk of B.
let a = self.a.slice(self.axes.a_k, offsets[ix], offsets[ix + 1])?;
let wire = patch.wire_node(
format!("{}.k-{}-{}", node.name, offsets[ix], offsets[ix + 1]),
MatMulUnary { a: a.into_arc_tensor(), ..self.clone() },
&[wire],
)?[0];
wires.push(wire)
}
// Accumulate the partial products with pairwise adds.
let mut wire = wires[0];
for (ix, w) in wires[1..].iter().enumerate() {
wire = patch.wire_node(
format!("{}.k-add-{}", node.name, ix),
crate::ops::binary::TypedBinOp(Box::new(crate::ops::math::Add)),
&[wire, *w],
)?[0];
}
patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
return Ok(Some(patch));
}
}
Ok(None)
}
/// Replace the generic binary op by its unicast (same-shape, in-place)
/// variant when both inputs share the same fact and the result keeps the
/// left input's datum type.
fn codegen(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let inputs = model.node_input_facts(node.id)?;
    let keeps_left_dt = self.0.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?
        == inputs[0].datum_type;
    if keeps_left_dt && inputs[0] == inputs[1] {
        let patch = TypedModelPatch::replace_single_op(
            model,
            node,
            &node.inputs,
            MergeOpUnicast(self.0.clone()),
        )?;
        Ok(Some(patch))
    } else {
        Ok(None)
    }
}
as_op!();
}
/// Wrapper running a `BinMiniOp` element-wise over two inputs that are known
/// to have identical shapes (no broadcasting), evaluated in place.
#[derive(Debug, Clone, Hash)]
pub struct MergeOpUnicast(pub Box<dyn BinMiniOp>);
impl_dyn_hash!(MergeOpUnicast);
impl Op for MergeOpUnicast {
// Name is derived from the wrapped mini-op, e.g. "AddUnicast".
fn name(&self) -> Cow<str> {
format!("{}Unicast", self.0.name()).into()
}
op_as_typed_op!();
}
impl EvalOp for MergeOpUnicast {
    fn is_stateless(&self) -> bool {
        true
    }
    /// Evaluate by applying the wrapped mini-op in place over the second
    /// operand's buffer.
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        let mut out = b.into_tensor();
        self.0.eval_unicast_in_place(&a, &mut out)?;
        Ok(tvec!(out.into_tvalue()))
    }
}
impl TypedOp for MergeOpUnicast {
    /// Output fact is the first input's fact (shapes are expected identical).
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        debug_assert_eq!(inputs[0].shape, inputs[1].shape);
        Ok(tvec!(inputs[0].clone()))
    }
    /// Cost: per-element cost of the mini-op scaled by the element count.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let elements: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        let per_element = self.0.cost_per_element(inputs[0].datum_type);
        Ok(per_element.into_iter().map(|(cost, n)| (cost, elements.clone() * n)).collect())
    }
    /// Delegate decluttering to the wrapped mini-op.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        self.0.declutter(model, node)
    }
    as_op!();
}
/// Generates a `BinMiniOp` implementation (`$Op`) plus a constructor function
/// (`$func`) for an element-wise binary operation.
///
/// Per-type kernels are supplied as `[types] => closure` arms taking
/// `(&mut c, &a, &b)`; quantized arms (after `q:`) additionally receive the
/// zero-point and scale. Optional named parameters (`codegen:`, `cost:`,
/// `declutter:`, `eval_override:`, `linalg:`, `operating_datum_type:`,
/// `out_of_place:`, `validation:`) override the corresponding trait methods.
#[macro_export]
macro_rules! bin_to_super_type {
($func:ident, $Op:ident,
$(codegen: $codegen:expr,)?
$(cost: $cost:expr,)?
$(declutter: $declutter:expr,)?
$(eval_override: $eval_override: expr,)?
$(linalg: $linalg:ident,)?
$(operating_datum_type: $operating_datum_type:expr,)?
$(out_of_place: $out_of_place:expr,)?
$(validation: $validation:expr,)?
$(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
$( [$($typ:ident),*] => $cab:expr),*) => {
#[derive(Debug, Clone, Hash)]
pub struct $Op;
tract_data::internal::impl_dyn_hash!($Op);
#[allow(clippy::redundant_closure_call)]
impl $crate::ops::binary::BinMiniOp for $Op {
fn name(&self) -> &'static str {
stringify!($Op)
}
// a is a scalar, b is mutated in place: b[i] = cab(a, b[i]).
fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
$(
$(if a.datum_type() == $typ::datum_type() {
let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
let a = a.to_scalar::<$typ>()?;
let b = b.as_slice_mut::<$typ>()?;
unsafe {
for i in 0..b.len() {
let mut c = $typ::default();
cab(&mut c, a, b.get_unchecked_mut(i));
b[i] = c;
}
}
return Ok(())
}
)*
)*
$(
$(
// Quantized arms: matched on the unquantized datum type, the
// kernel receives the zero-point and scale from a's qparams.
$(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
let a = a.to_scalar::<$typ_dt>()?;
let b = b.as_slice_mut::<$typ_dt>()?;
unsafe {
for i in 0..b.len() {
let mut c = $typ_dt::default();
cab(&mut c, a, b.get_unchecked_mut(i), zp, scale);
b[i] = c;
}
}
return Ok(())
}
)*
)*
)?
bail!("{} does not support {:?} (inplace uniform)", self.name(), a.datum_type());
}
// a and b have the same length: b[i] = cab(a[i], b[i]).
fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
$(
$(if a.datum_type() == $typ::datum_type() {
let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
let a = a.as_slice::<$typ>()?;
let b = b.as_slice_mut::<$typ>()?;
unsafe {
for i in 0..a.len() {
let mut c = $typ::default();
cab(&mut c, &a[i], b.get_unchecked(i));
*b.get_unchecked_mut(i) = c;
}
}
return Ok(())
}
)*
)*
$(
$(
$(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
let a = a.as_slice::<$typ_dt>()?;
let b = b.as_slice_mut::<$typ_dt>()?;
unsafe {
for i in 0..a.len() {
let mut c = $typ_dt::default();
cab(&mut c, &a[i], b.get_unchecked(i), zp, scale);
*b.get_unchecked_mut(i) = c;
}
}
return Ok(())
}
)*
)*
)?
bail!("{} does not support {:?} (inplace)", self.name(), a.datum_type());
}
// Broadcasting evaluation into a pre-allocated output tensor c.
fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
$(if $out_of_place(c, a, b)? { return Ok(()) } )?
$(
$(if c.datum_type() == $typ::datum_type() {
let a = a.to_array_view::<$typ>()?;
let b = b.to_array_view::<$typ>()?;
let mut c = c.to_array_view_mut::<$typ>()?;
$crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
return Ok(())
})*
)*
$(
$(
$(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
let a = a.to_array_view::<$typ_dt>()?;
let b = b.to_array_view::<$typ_dt>()?;
let mut c = c.to_array_view_mut::<$typ_dt>()?;
$crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
return Ok(())
}
)*
)*
)?
bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
}
// In-place over a (a also acts as output); b broadcasts over a.
fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
// c and a are same type
$(
$(if b.datum_type() == $typ::datum_type() {
let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
let b = b.to_array_view::<$typ>()?;
let mut a = a.to_array_view_mut::<$typ>()?;
$crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
return Ok(())
})*
)*
/*
$(
$(
$(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
let mut a = a.to_array_view_mut::<$typ_dt>()?;
let b = b.to_array_view::<$typ_dt>()?;
$crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, a, b, zp, scale));
return Ok(())
}
)*
)*
)?
*/
bail!("{} does not support {:?} (out of place)", self.name(), a.datum_type());
}
$(fn eval(&self, a: TValue, b: TValue) -> TractResult<Tensor> {
$eval_override(a, b)
})?
// Keep the quantized flavour when both sides share an unquantized type.
fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
if a.unquantized() == b.unquantized() {
if a.is_quantized() || !b.is_quantized() {
return Ok(a)
}
else {
return Ok(b)
}
}
self.operating_datum_type(a, b)
}
$(
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
($declutter)(self, model, node)
}
)?
$(
fn codegen(
&self,
model: &TypedModel,
node: &TypedNode,
a: &Arc<Tensor>,
) -> TractResult<Option<TypedModelPatch>> {
($codegen)(self, model, node, a)
}
)?
$(
fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
($cost)(dt)
}
)?
$(
fn validation(&self) -> Validation {
$validation
}
)?
$(
fn as_linalg_binop(&self) -> Option<tract_linalg::mmm::BinOp> {
Some(tract_linalg::mmm::BinOp::$linalg)
}
)?
$(
fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
($operating_datum_type)(a, b)
})?
}
// Public constructor wrapping the op in a TypedBinOp.
pub fn $func() -> $crate::ops::binary::TypedBinOp {
$crate::ops::binary::TypedBinOp(Box::new($Op))
}
};
}
/// Like `bin_to_super_type!` but for comparison-style operations whose result
/// datum type is always `bool`. Generates a `BinMiniOp` impl (`$Op`) and a
/// `TypedBinOp` constructor (`$func`); `[types] => closure` arms provide the
/// per-type kernels `(&mut bool, &a, &b)`.
macro_rules! bin_to_bool {
($func:ident, $Op:ident,
$( codegen: $codegen:expr, )?
$( cost: $cost:expr, )?
$( declutter: $declutter:expr, )?
$( operating_datum_type: $operating_datum_type:expr, )?
$( [$($typ:ident),*] => $cab:expr),*) => {
#[derive(Debug, Clone, Hash)]
pub struct $Op;
tract_data::internal::impl_dyn_hash!($Op);
impl $crate::ops::binary::BinMiniOp for $Op {
fn name(&self) -> &'static str {
stringify!($Op)
}
// NOTE(review): this path reads b as a bool slice, so it is only
// reachable when the op's input type is bool itself.
fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
$(
$(if a.datum_type() == $typ::datum_type() {
let cab: fn(&mut bool, &bool, &bool) -> () = $cab;
let a = a.to_scalar::<bool>()?;
let b = b.as_slice_mut::<bool>()?;
unsafe {
for i in 0..b.len() {
let mut c = bool::default();
cab(&mut c, a, b.get_unchecked(i));
*b.get_unchecked_mut(i) = c;
}
}
return Ok(())
}
)*
)*
bail!("{} does not support {:?} (inplace uniform)", self.name(), a.datum_type());
}
#[allow(unreachable_code)]
fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
$(
$(if a.datum_type() == $typ::datum_type() {
let cab: fn(&mut bool, &bool, &bool) -> () = $cab;
let a = a.as_slice::<bool>()?;
let b = b.as_slice_mut::<bool>()?;
unsafe {
for i in 0..a.len() {
let mut c = bool::default();
cab(&mut c, a.get_unchecked(i), b.get_unchecked(i));
*b.get_unchecked_mut(i) = c;
}
}
return Ok(())
}
)*
)*
bail!("{} does not support {:?}", self.name(), a.datum_type());
}
// General broadcasting path: compares $typ inputs into a bool output.
fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
$(
$(if a.datum_type() == $typ::datum_type() {
let cab: fn(&mut bool, &$typ, &$typ) -> () = $cab;
let a = a.to_array_view::<$typ>()?;
let b = b.to_array_view::<$typ>()?;
let mut c = c.to_array_view_mut::<bool>()?;
ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(cab);
return Ok(())
}
)*
)*
bail!("{} does not support {:?}", self.name(), a.datum_type());
}
// In-place over a is impossible in general: output is bool, input is $typ.
fn eval_in_a(&self, a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
bail!("{} does not support {:?}", self.name(), a.datum_type());
}
fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
Ok(bool::datum_type())
}
$(
fn codegen(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
($codegen)(self, model, node)
}
)?
$(
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
($declutter)(self, model, node)
}
)?
$(
fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
($cost)(dt)
}
)?
$(
fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
($operating_datum_type)(a, b)
})?
}
// Public constructor wrapping the op in a TypedBinOp.
pub fn $func() -> $crate::ops::binary::TypedBinOp {
$crate::ops::binary::TypedBinOp(Box::new($Op))
}
};
}
/// Description of a binary node where exactly one side is a uniform
/// (constant, single-valued) tensor.
#[derive(Debug)]
pub(crate) struct OneUniformInput {
// The uniform constant value.
pub uni: Arc<Tensor>,
// Outlet feeding the non-uniform side.
pub var: OutletId,
// True when the uniform operand is the left (first) input.
pub left_is_uniform: bool,
}
/// If exactly one of the node's two inputs is uniform and the uniform side's
/// shape broadcasts over the variable side (each dim is 1 or equal), return
/// its description. Returns `None` otherwise (including when both are uniform
/// — the left one wins the scan but the shape check still applies).
pub(crate) fn one_input_is_uniform(
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<OneUniformInput>> {
if let &[a, b] = &*model.node_input_facts(node.id)? {
let uni = if let Some(a) = &a.uniform {
OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
} else if let Some(b) = &b.uniform {
OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
} else {
return Ok(None);
};
let var_fact = [a, b][uni.left_is_uniform as usize];
let uni_fact = [a, b][!uni.left_is_uniform as usize];
// The uniform side must broadcast onto the variable side.
if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
return Ok(Some(uni))
}
}
Ok(None)
}
/// Replace a dynamic slice by a static `Slice` op once both bounds are known:
/// start/end come either from constant inputs (when `start_input`/`end_input`
/// are set) or default to 0 and the full axis length.
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
let start =
if self.start_input { inputs[1].konst.clone() } else { Some(rctensor0(TDim::zero())) };
// When both start and end are inputs, end is input #2, otherwise #1.
let end = if self.end_input {
inputs[1 + self.start_input as usize].konst.clone()
} else {
Some(rctensor0(inputs[0].shape[self.axis].clone()))
};
if let (Some(start), Some(end)) = (start, end) {
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&[node.inputs[0]],
crate::ops::array::Slice {
axis: self.axis,
start: start.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone(),
end: end.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone(),
},
)?));
}
Ok(None)
}
/// If a state is initialized from a constant outer input, fold that constant
/// into the mapping (`StateInitializer::Value`) and drop the now-unneeded
/// outer input. One mapping is folded per invocation; the optimizer re-runs
/// until fixpoint.
fn declutter_const_initializer(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
for (ix, mapping) in self.input_mapping.iter().enumerate() {
if let InputMapping::State { initializer: StateInitializer::FromInput(n) } = mapping {
if let Some(i) = inputs[*n].konst.as_ref() {
let mut op = self.clone();
// Embed the constant value directly in the mapping.
op.input_mapping[ix] =
InputMapping::State { initializer: StateInitializer::Value(i.clone()) };
// Renumber the remaining mappings past the removed input.
op.input_mapping =
Self::remove_outer_input_from_mappings(&op.input_mapping, *n);
let mut inputs = node.inputs.clone();
inputs.remove(*n);
return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
}
}
}
Ok(None)
}
/// Rewrite a comparison against a uniform zero constant into the matching
/// element-wise "…ThanZero" op (for signed or float types only). When the
/// zero is the left operand the comparison direction is flipped.
fn codegen_compare_to_zero(
op: &dyn BinMiniOp,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let facts = model.node_input_facts(node.id)?;
if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
let dt = facts[0].datum_type;
if (dt.is_signed() || dt.is_float()) && *uniform.uni == Tensor::zero_scalar_dt(dt)? {
let reversed = uniform.left_is_uniform;
let mapped = || -> Box<dyn ElementWiseMiniOp> {
// `m!` maps a comparison to its zero-compare op and to the
// mirrored op used when the zero is on the left ("0 < x" == "x > 0").
macro_rules! m {
($bin: ty, $same: expr, $other: expr) => {
if op.is::<$bin>() {
return if reversed {Box::new($other) } else {Box::new($same)}
};
}
}
m!(Less, LessThanZero {}, GreaterEqualThanZero {});
m!(LessEqual, LessEqualThanZero {}, GreaterThanZero {});
m!(Greater, GreaterThanZero {}, LessEqualThanZero {});
m!(GreaterEqual, GreaterEqualThanZero {}, LessThanZero {});
// Only the four comparisons above call into this helper.
unreachable!();
};
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&[uniform.var],
ElementWiseOp(mapped()),
)?));
}
}
Ok(None)
}
pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>>
Get output tensor information for a node.
Examples found in repository?
359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644
pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
}
/// Get input tensor information for a node.
pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
}
/// Get output tensor information for a node.
pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
}
// outlets
/// Get tensor information for a single outlet.
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
let outlets = &self.nodes[outlet.node].outputs;
outlets
.get(outlet.slot)
.map(|o| &o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get tensor information for a single outlet.
pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
let outlets = &mut self.nodes[outlet.node].outputs;
outlets
.get_mut(outlet.slot)
.map(|o| &mut o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get multiple mutable tensor information for outlets.
pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
unsafe {
outlets
.iter()
.map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
.collect()
}
}
/// Set tensor information for a single outlet.
pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
    let outlets = &mut self.nodes[outlet.node].outputs;
    if outlets.len() <= outlet.slot {
        // Fixed typo in the error message ("refererence" -> "reference").
        bail!("Invalid outlet reference: {:?}", outlet)
    }
    outlets[outlet.slot].fact = fact;
    Ok(())
}
/// Set tensor information for a single outlet and return `self`.
pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
self.set_outlet_fact(outlet, fact)?;
Ok(self)
}
// outlet labels
/// Get label for an outlet.
pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
self.outlet_labels.get(&outlet).map(|s| &**s)
}
/// Set label for an outlet.
pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
self.outlet_labels.insert(outlet, label);
Ok(())
}
/// Set label for an outlet and return `self`.
pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
self.set_outlet_label(outlet, label)?;
Ok(self)
}
/// Find outlet by label.
pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
}
// misc
/// Computes an evaluation order for the graph inputs and outputs
pub fn eval_order(&self) -> TractResult<Vec<usize>> {
eval_order(self)
}
#[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
#[inline]
pub fn check_edges(&self) -> TractResult<()> {
Ok(())
}
/// Performs a sanity check on network connections.
#[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
#[inline]
pub fn check_edges(&self) -> TractResult<()> {
for node_id in self.eval_order()? {
let node = &self.nodes[node_id];
for (ix, input) in node.inputs.iter().enumerate() {
let prec = &self.nodes[input.node];
if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
bail!(
"Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
node.id,
ix,
prec
)
}
}
for (ix, output) in node.outputs.iter().enumerate() {
for succ in &output.successors {
if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
bail!(
"Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
node.id,
ix,
succ
)
}
}
}
}
Ok(())
}
/// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
crate::plan::SimplePlan::new(self)
}
pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
let node = &self.nodes()[id];
if node.inputs.len() != 1 {
return Ok(None);
}
let prec = &self.nodes()[node.inputs[0].node];
if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
return Ok(None);
}
Ok(Some(prec))
}
pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
let mut node = self.node(id);
for _ in 0..count {
if let Some(next) = self.single_prec(node.id)? {
node = next
} else {
return Ok(None);
}
}
Ok(Some(node))
}
/// Walk `count` steps forwards through unique successors.
///
/// Starting at `id`, repeatedly applies `single_succ`; returns `Ok(None)`
/// as soon as a step has no unique successor. With `count == 0` this
/// returns the starting node itself.
pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
    let mut current = self.node(id);
    for _ in 0..count {
        match self.single_succ(current.id)? {
            Some(next) => current = next,
            None => return Ok(None),
        }
    }
    Ok(Some(current))
}
/// Return the unique successor of node `id`, if any.
///
/// Yields `Some(node)` only when `id` feeds exactly one consumer in total
/// (summed over all output slots) AND that consumer has exactly one
/// input; otherwise `Ok(None)`.
pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
    let node = &self.nodes()[id];
    let fanout: usize = node.outputs.iter().map(|o| o.successors.len()).sum();
    if fanout != 1 {
        return Ok(None);
    }
    // fanout == 1, and we look it up on the first output slot, as the
    // original does.
    let inlet = node.outputs[0].successors[0];
    let downstream = &self.nodes()[inlet.node];
    Ok(if downstream.inputs.len() == 1 { Some(downstream) } else { None })
}
/// List the inlets (consumer node/slot pairs) fed by `outlet`.
pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
    &self.nodes[outlet.node].outputs[outlet.slot].successors
}
}
// Note: the `F: Fact + Clone + 'static` bound was previously stated twice
// (inline on the impl generics and again in the where clause); it is kept
// once, in the where clause, for clarity.
impl<F, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Add a constant node holding tensor `v` to the graph.
    ///
    /// The fact for the new outlet is derived from the tensor itself via
    /// `F::from`. Returns the outlet of the new `Const` node.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let v = v.into_arc_tensor();
        let fact = F::from(v.clone());
        let name = name.into();
        self.add_node(name, crate::ops::konst::Const::new(v), tvec!(fact)).map(|id| id.into())
    }
}
impl<F, O> fmt::Display for Graph<F, O>
where
F: Fact + Hash + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
for i in 0..self.nodes.len() {
let input_1 = self.nodes[i]
.inputs
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let input_2 = self.nodes[i]
.inputs
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_1 = self
.outlet_successors(OutletId::new(i, 0))
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_2 = self
.outlet_successors(OutletId::new(i, 0))
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
writeln!(
fmt,
"{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
i,
input_1,
input_2,
output_1,
output_2,
self.nodes[i].op().name(),
self.nodes[i].name,
self.node_input_facts(i).unwrap(),
self.node_output_facts(i).unwrap(),
)?;
if self.nodes[i].inputs.len() > 2 {
writeln!(
fmt,
" | * inputs: {}",
self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
)?;
}
if self.nodes[i].outputs.len() > 1
|| self.outlet_successors((i, 0).into()).len() > 2
|| (self.outlet_label(i.into()).is_some()
&& self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
{
for o in 0..self.nodes[i].outputs.len() {
if self.outlet_successors((i, o).into()).len() > 0 {
writeln!(
fmt,
" | * output #{}: {} {}",
o,
self.outlet_label((i, o).into()).unwrap_or(""),
self.outlet_successors((i, o).into())
.iter()
.map(|s| format!("{:?}", s))
.join(", "),
)?;
}
}
}
}
writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
Ok(())
}
More examples
206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313
pub fn exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
where
Eval: for<'a, 'b, 'c> FnMut(
&'a mut SessionState,
Option<&'b mut (dyn OpState + 'static)>,
&'c Node<F, O>,
TVec<TValue>,
) -> Result<TVec<TValue>, E>,
E: Into<anyhow::Error> + Send + Sync + 'static,
{
{
let &mut SimpleState {
ref plan,
ref mut session_state,
ref mut states,
ref mut values,
..
} = self;
let plan = plan.borrow();
let model = plan.model().borrow();
for (step, n) in plan.order.iter().enumerate() {
let node = model.node(*n);
trace!("Running step {}, node {}", step, node);
let mut inputs: TVec<TValue> = tvec![];
for i in &node.inputs {
trace!(" use input {:?}", i);
let prec_node = model.node(i.node);
let prec = values[i.node].as_ref().ok_or_else(|| {
format_err!("Computing {}, precursor {} not done:", node, prec_node)
})?;
inputs.push(prec[i.slot].clone())
}
for flush in &plan.flush_lists[step] {
trace!(" Ran {} can now flush {}", node, model.node(*flush));
values[*flush] = None;
}
if cfg!(debug_assertions) {
let facts = model.node_input_facts(node.id)?;
if facts.len() != inputs.len() {
bail!(
"Evaluating {}: expected {} inputs, got {}",
node,
facts.len(),
inputs.len()
);
}
for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
if !f.matches(v, Some(&session_state.resolved_symbols))? {
bail!(
"Evaluating {}: input {:?}, expected {:?}, got {:?}",
node,
ix,
f,
v
);
}
}
}
let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)
.map_err(|e| e.into())?;
if plan.has_unresolved_symbols {
for (o, v) in node.outputs.iter().zip(vs.iter()) {
if let Ok(f) = o.fact.to_typed_fact() {
for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
Self::resolve(
&mut session_state.resolved_symbols,
&dim_abstract,
*dim_concrete as i64,
);
}
}
}
}
if cfg!(debug_assertions) {
let facts = model.node_output_facts(node.id)?;
if facts.len() != vs.len() {
bail!(
"Evaluating {}: expected {} outputs, got {}",
node,
facts.len(),
vs.len()
);
}
for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
if node.outputs[ix].successors.len() == 0 {
continue;
}
if !f.matches(v, Some(&session_state.resolved_symbols))? {
bail!(
"Evaluating {}: output {:?}, expected {:?}, got {:?}",
node,
ix,
f,
v
);
}
}
}
values[node.id] = Some(vs);
}
}
Ok(())
}
source
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F>
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F>
Get tensor information for a single outlet.
Examples found in repository?
187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403
pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
let input = self.input_outlets()?[ix];
self.outlet_fact(input)
}
/// Get the `ix`-th input tensor type information, mutably.
pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
let input = self.input_outlets()?[ix];
self.outlet_fact_mut(input)
}
/// Set the `ix`-th input tensor type information.
pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
let outlet = self.inputs[input];
self.set_outlet_fact(outlet, fact)
}
/// Set the `ix`-th input tensor type information and return `self`.
pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
self.set_input_fact(input, fact)?;
Ok(self)
}
// Outputs
/// Get model outputs.
pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
Ok(&self.outputs)
}
/// Guess outputs from the topology: node or nodes with no successors.
pub fn auto_outputs(&mut self) -> TractResult<()> {
let outputs = self
.nodes
.iter()
.flat_map(|n| {
let id = n.id;
n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
(OutletId::new(id, ix), output_fact.successors.len())
})
})
.filter(|(_f, succs)| *succs == 0)
.map(|(f, _)| f)
.collect();
self.outputs = outputs;
Ok(())
}
/// Change model outputs.
pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
self.outputs = outputs.to_vec();
Ok(())
}
/// Change model outputs and return `self`.
pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
self.set_output_outlets(outputs)?;
Ok(self)
}
/// Set model outputs by node names.
pub fn set_output_names(
&mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<()> {
let mut labels: HashMap<Cow<str>, OutletId> =
self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
for n in self.nodes() {
for ix in 0..n.outputs.len() {
labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
}
}
let ids: Vec<OutletId> = outputs
.into_iter()
.map(|s| {
let s = s.as_ref();
labels
.get(s)
.cloned()
.or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
.ok_or_else(|| format_err!("Node {} not found", s))
})
.collect::<TractResult<_>>()?;
self.outputs = ids;
Ok(())
}
/// Set model outputs by node names and return `self`.
pub fn with_output_names(
mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<Self> {
self.set_output_names(outputs)?;
Ok(self)
}
/// Get the `ix`-th input tensor type information.
pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
let output = self.output_outlets()?[ix];
self.outlet_fact(output)
}
/// Get the `ix`-th input tensor type information, mutably.
pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
let output = self.output_outlets()?[ix];
self.outlet_fact_mut(output)
}
/// Set the `ix`-th output tensor type information.
pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
let outlet = self.outputs[output];
self.set_outlet_fact(outlet, fact)
}
/// Set the `ix`-th output tensor type information and return `self`.
pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
self.set_output_fact(output, fact)?;
Ok(self)
}
// nodes and their facts
/// Iterate over all node names.
pub fn node_names(&self) -> impl Iterator<Item = &str> {
self.nodes.iter().map(|s| &*s.name)
}
pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
self.nodes
.iter()
.find(|n| n.name == name)
.map(|n| n.id)
.with_context(|| format!("No node found for name: \"{}\"", name))
}
/// Find a node by its name.
pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
let id: usize = self.node_id_by_name(name.as_ref())?;
Ok(&self.nodes[id])
}
/// Borrow mutably a node by its name.
pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
let id: usize = self.node_id_by_name(name.as_ref())?;
Ok(&mut self.nodes[id])
}
pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
self.node_mut(id).name = name.to_string();
Ok(())
}
/// Find a node by its id.
pub fn node(&self, id: usize) -> &Node<F, O> {
&self.nodes[id]
}
/// Find a node by its id.
pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
&mut self.nodes[id]
}
/// Access the nodes table.
pub fn nodes(&self) -> &[Node<F, O>] {
&self.nodes
}
/// Access the nodes table.
pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
&mut self.nodes
}
/// Get input and output tensor information for a node.
pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
}
/// Get input tensor information for a node.
pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
}
/// Get output tensor information for a node.
pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
}
// outlets
/// Get tensor information for a single outlet.
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
let outlets = &self.nodes[outlet.node].outputs;
outlets
.get(outlet.slot)
.map(|o| &o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get tensor information for a single outlet.
pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
let outlets = &mut self.nodes[outlet.node].outputs;
outlets
.get_mut(outlet.slot)
.map(|o| &mut o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get multiple mutable tensor information for outlets.
pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
unsafe {
outlets
.iter()
.map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
.collect()
}
}
More examples
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
pub fn tap_model(&mut self, model: &Graph<F, O>, outlet: OutletId) -> TractResult<OutletId> {
let fact = model.outlet_fact(outlet)?;
let id = self.add_source(
format!("incoming-{}/{}", outlet.node, outlet.slot),
dyn_clone::clone(fact),
)?;
self.incoming.insert(id, outlet);
Ok(id)
}
pub unsafe fn shunt_outside_unchecked(
&mut self,
outlet: OutletId,
by: OutletId,
) -> TractResult<()> {
self.shunt_outlet_by.insert(outlet, by);
Ok(())
}
/// Replace an Outlet in the target model by one from the patch.
pub fn shunt_outside(
&mut self,
model: &Graph<F, O>,
outlet: OutletId,
by: OutletId,
) -> TractResult<()> {
let original_fact = model.outlet_fact(outlet)?;
let new_fact = self.model.outlet_fact(by)?;
if !original_fact.compatible_with(new_fact) {
bail!("Trying to substitute a {:?} by {:?}.\n{:?}", original_fact, new_fact, self);
}
self.shunt_outlet_by.insert(outlet, by);
Ok(())
}136 137 138 139 140 141 142 143 144 145 146 147
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if self.start.is_zero() && (self.end == model.outlet_fact(node.inputs[0])?.shape[self.axis])
{
Ok(Some(TypedModelPatch::shunt_one_op(model, node)?.with_context("noop")))
} else {
Ok(None)
}
}227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let input_fact = model.outlet_fact(node.inputs[0])?;
if node.inputs.len() == 2
&& model.outlet_fact(node.inputs[1])?.konst.as_ref().and_then(|t| t.as_uniform())
== Some(Tensor::zero_scalar_dt(input_fact.datum_type)?)
{
Ok(Some(
TypedModelPatch::replace_single_op(model, node, &node.inputs[0..1], self.clone())?
.with_context("b0 is zero"),
))
} else {
Ok(None)
}
}324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
pub fn full_axis_tracking(model: &TypedModel) -> TractResult<Vec<AxisTracking>> {
let mut axes: Vec<AxisTracking> = vec![];
for node in model.eval_order()? {
for slot in 0..model.node(node).outputs.len() {
let outlet = OutletId::new(node, slot);
let input_fact = model.outlet_fact(outlet)?;
'axis: for axis in 0..input_fact.rank() {
if axes.iter().any(|tracking| tracking.outlets.get(&outlet) == Some(&axis)) {
continue 'axis;
}
axes.push(AxisTracking::for_outlet_and_axis(model, outlet, axis)?);
}
}
}
Ok(axes)
}- src/ops/binary.rs
- src/ops/matmul/mir_quant.rs
- src/ops/quant.rs
- src/ops/matmul/mir.rs
- src/plan.rs
- src/ops/downsample/conv.rs
- src/ops/matmul/mir_unary.rs
- src/ops/array/gather_nd.rs
- src/ops/array/gather.rs
- src/model/typed.rs
- src/ops/math/mod.rs
- src/ops/cnn/conv/unary.rs
- src/ops/scan/mir.rs
- src/ops/matmul/lir_unary.rs
- src/ops/cnn/deconv/unary.rs
sourcepub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F>
pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F>
Get tensor information for a single outlet.
Examples found in repository?
193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292
pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
let input = self.input_outlets()?[ix];
self.outlet_fact_mut(input)
}
/// Set the `ix`-th input tensor type information.
pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
let outlet = self.inputs[input];
self.set_outlet_fact(outlet, fact)
}
/// Set the `ix`-th input tensor type information and return `self`.
pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
self.set_input_fact(input, fact)?;
Ok(self)
}
// Outputs
/// Get model outputs.
pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
Ok(&self.outputs)
}
/// Guess outputs from the topology: node or nodes with no successors.
pub fn auto_outputs(&mut self) -> TractResult<()> {
let outputs = self
.nodes
.iter()
.flat_map(|n| {
let id = n.id;
n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
(OutletId::new(id, ix), output_fact.successors.len())
})
})
.filter(|(_f, succs)| *succs == 0)
.map(|(f, _)| f)
.collect();
self.outputs = outputs;
Ok(())
}
/// Change model outputs.
pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
self.outputs = outputs.to_vec();
Ok(())
}
/// Change model outputs and return `self`.
pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
self.set_output_outlets(outputs)?;
Ok(self)
}
/// Set model outputs by node names.
pub fn set_output_names(
&mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<()> {
let mut labels: HashMap<Cow<str>, OutletId> =
self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
for n in self.nodes() {
for ix in 0..n.outputs.len() {
labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
}
}
let ids: Vec<OutletId> = outputs
.into_iter()
.map(|s| {
let s = s.as_ref();
labels
.get(s)
.cloned()
.or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
.ok_or_else(|| format_err!("Node {} not found", s))
})
.collect::<TractResult<_>>()?;
self.outputs = ids;
Ok(())
}
/// Set model outputs by node names and return `self`.
pub fn with_output_names(
mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<Self> {
self.set_output_names(outputs)?;
Ok(self)
}
/// Get the `ix`-th input tensor type information.
pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
let output = self.output_outlets()?[ix];
self.outlet_fact(output)
}
/// Get the `ix`-th input tensor type information, mutably.
pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
let output = self.output_outlets()?[ix];
self.outlet_fact_mut(output)
}sourcepub fn outlets_fact_mut(
&mut self,
outlets: &[OutletId]
) -> TractResult<TVec<&mut F>>
pub fn outlets_fact_mut(
&mut self,
outlets: &[OutletId]
) -> TractResult<TVec<&mut F>>
Get multiple mutable tensor information for outlets.
sourcepub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()>
pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()>
Set tensor information for a single outlet.
Examples found in repository?
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419
pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
let outlet = self.inputs[input];
self.set_outlet_fact(outlet, fact)
}
/// Set the `ix`-th input tensor type information and return `self`.
pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
self.set_input_fact(input, fact)?;
Ok(self)
}
// Outputs
/// Get model outputs.
pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
Ok(&self.outputs)
}
/// Guess outputs from the topology: node or nodes with no successors.
pub fn auto_outputs(&mut self) -> TractResult<()> {
let outputs = self
.nodes
.iter()
.flat_map(|n| {
let id = n.id;
n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
(OutletId::new(id, ix), output_fact.successors.len())
})
})
.filter(|(_f, succs)| *succs == 0)
.map(|(f, _)| f)
.collect();
self.outputs = outputs;
Ok(())
}
/// Change model outputs.
pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
self.outputs = outputs.to_vec();
Ok(())
}
/// Change model outputs and return `self`.
pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
self.set_output_outlets(outputs)?;
Ok(self)
}
/// Set model outputs by node names.
pub fn set_output_names(
&mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<()> {
let mut labels: HashMap<Cow<str>, OutletId> =
self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
for n in self.nodes() {
for ix in 0..n.outputs.len() {
labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
}
}
let ids: Vec<OutletId> = outputs
.into_iter()
.map(|s| {
let s = s.as_ref();
labels
.get(s)
.cloned()
.or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
.ok_or_else(|| format_err!("Node {} not found", s))
})
.collect::<TractResult<_>>()?;
self.outputs = ids;
Ok(())
}
/// Set model outputs by node names and return `self`.
pub fn with_output_names(
mut self,
outputs: impl IntoIterator<Item = impl AsRef<str>>,
) -> TractResult<Self> {
self.set_output_names(outputs)?;
Ok(self)
}
/// Get the `ix`-th input tensor type information.
pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
let output = self.output_outlets()?[ix];
self.outlet_fact(output)
}
/// Get the `ix`-th input tensor type information, mutably.
pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
let output = self.output_outlets()?[ix];
self.outlet_fact_mut(output)
}
/// Set the `ix`-th output tensor type information.
pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
let outlet = self.outputs[output];
self.set_outlet_fact(outlet, fact)
}
/// Set the `ix`-th output tensor type information and return `self`.
pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
self.set_output_fact(output, fact)?;
Ok(self)
}
// nodes and their facts
/// Iterate over all node names.
pub fn node_names(&self) -> impl Iterator<Item = &str> {
self.nodes.iter().map(|s| &*s.name)
}
pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
self.nodes
.iter()
.find(|n| n.name == name)
.map(|n| n.id)
.with_context(|| format!("No node found for name: \"{}\"", name))
}
/// Find a node by its name.
pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
let id: usize = self.node_id_by_name(name.as_ref())?;
Ok(&self.nodes[id])
}
/// Borrow mutably a node by its name.
pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
let id: usize = self.node_id_by_name(name.as_ref())?;
Ok(&mut self.nodes[id])
}
pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
self.node_mut(id).name = name.to_string();
Ok(())
}
/// Find a node by its id.
pub fn node(&self, id: usize) -> &Node<F, O> {
&self.nodes[id]
}
/// Find a node by its id.
pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
&mut self.nodes[id]
}
/// Access the nodes table.
pub fn nodes(&self) -> &[Node<F, O>] {
&self.nodes
}
/// Access the nodes table.
pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
&mut self.nodes
}
/// Get input and output tensor information for a node.
pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
}
/// Get input tensor information for a node.
pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
}
/// Get output tensor information for a node.
pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
}
// outlets
/// Get tensor information for a single outlet.
pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
let outlets = &self.nodes[outlet.node].outputs;
outlets
.get(outlet.slot)
.map(|o| &o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get tensor information for a single outlet.
pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
let outlets = &mut self.nodes[outlet.node].outputs;
outlets
.get_mut(outlet.slot)
.map(|o| &mut o.fact)
.with_context(|| format!("Invalid outlet reference: {:?}", outlet))
}
/// Get multiple mutable tensor information for outlets.
pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
unsafe {
outlets
.iter()
.map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
.collect()
}
}
/// Set tensor information for a single outlet.
///
/// Fails (instead of panicking) when `outlet.slot` does not exist on the
/// referenced node. The error message previously misspelled "reference".
pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
    let outlets = &mut self.nodes[outlet.node].outputs;
    if outlets.len() <= outlet.slot {
        bail!("Invalid outlet reference: {:?}", outlet)
    }
    outlets[outlet.slot].fact = fact;
    Ok(())
}
/// Set tensor information for a single outlet and return `self`.
pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
self.set_outlet_fact(outlet, fact)?;
Ok(self)
}
source
pub fn with_outlet_fact(self, outlet: OutletId, fact: F) -> TractResult<Self>
pub fn with_outlet_fact(self, outlet: OutletId, fact: F) -> TractResult<Self>
Set tensor information for a single outlet and return self.
sourcepub fn outlet_label(&self, outlet: OutletId) -> Option<&str>
pub fn outlet_label(&self, outlet: OutletId) -> Option<&str>
Get label for an outlet.
Examples found in repository?
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
fn translate_model_with_mappings(
&self,
source: &Graph<TI1, O1>,
) -> TractResult<(Graph<TI2, O2>, HashMap<OutletId, OutletId>)> {
let mut target = Graph::default();
let mut mapping = HashMap::new();
for old_id in source.eval_order()? {
let node = source.node(old_id);
trace!("Translating {} {:?}", node, self);
let outlets = self
.translate_node(source, node, &mut target, &mapping)
.with_context(|| format!("Translating node {} {:?}", node, self))?;
for (ix, outlet) in outlets.into_iter().enumerate() {
mapping.insert(OutletId::new(node.id, ix), outlet);
if let Some(label) = source.outlet_label(OutletId::new(node.id, ix)) {
target.set_outlet_label(outlet, label.to_string())?;
}
}
}
// do not drop inputs, even if they are useless, to maintain interface
for i in source.input_outlets()? {
if !mapping.contains_key(i) {
let node = source.node(i.node);
trace!("Translate useless source {}", node);
let outlets = self
.translate_node(source, node, &mut target, &mapping)
.with_context(|| format!("Translating input {} {:?}", node, self))?;
mapping.insert(*i, outlets[0]);
}
}
// maintaining order of i/o interface
target.inputs = source.input_outlets()?.iter().map(|i| mapping[i]).collect();
target.outputs = source.output_outlets()?.iter().map(|o| mapping[o]).collect();
target.symbol_table = source.symbol_table.clone();
target.properties = source.properties.clone();
Ok((target, mapping))
}More examples
579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
for i in 0..self.nodes.len() {
let input_1 = self.nodes[i]
.inputs
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let input_2 = self.nodes[i]
.inputs
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_1 = self
.outlet_successors(OutletId::new(i, 0))
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_2 = self
.outlet_successors(OutletId::new(i, 0))
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
writeln!(
fmt,
"{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
i,
input_1,
input_2,
output_1,
output_2,
self.nodes[i].op().name(),
self.nodes[i].name,
self.node_input_facts(i).unwrap(),
self.node_output_facts(i).unwrap(),
)?;
if self.nodes[i].inputs.len() > 2 {
writeln!(
fmt,
" | * inputs: {}",
self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
)?;
}
if self.nodes[i].outputs.len() > 1
|| self.outlet_successors((i, 0).into()).len() > 2
|| (self.outlet_label(i.into()).is_some()
&& self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
{
for o in 0..self.nodes[i].outputs.len() {
if self.outlet_successors((i, o).into()).len() > 0 {
writeln!(
fmt,
" | * output #{}: {} {}",
o,
self.outlet_label((i, o).into()).unwrap_or(""),
self.outlet_successors((i, o).into())
.iter()
.map(|s| format!("{:?}", s))
.join(", "),
)?;
}
}
}
}
writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
Ok(())
}sourcepub fn set_outlet_label(
&mut self,
outlet: OutletId,
label: String
) -> TractResult<()>
pub fn set_outlet_label(
&mut self,
outlet: OutletId,
label: String
) -> TractResult<()>
Set label for an outlet.
Examples found in repository?
More examples
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
fn translate_model_with_mappings(
&self,
source: &Graph<TI1, O1>,
) -> TractResult<(Graph<TI2, O2>, HashMap<OutletId, OutletId>)> {
let mut target = Graph::default();
let mut mapping = HashMap::new();
for old_id in source.eval_order()? {
let node = source.node(old_id);
trace!("Translating {} {:?}", node, self);
let outlets = self
.translate_node(source, node, &mut target, &mapping)
.with_context(|| format!("Translating node {} {:?}", node, self))?;
for (ix, outlet) in outlets.into_iter().enumerate() {
mapping.insert(OutletId::new(node.id, ix), outlet);
if let Some(label) = source.outlet_label(OutletId::new(node.id, ix)) {
target.set_outlet_label(outlet, label.to_string())?;
}
}
}
// do not drop inputs, even if they are useless, to maintain interface
for i in source.input_outlets()? {
if !mapping.contains_key(i) {
let node = source.node(i.node);
trace!("Translate useless source {}", node);
let outlets = self
.translate_node(source, node, &mut target, &mapping)
.with_context(|| format!("Translating input {} {:?}", node, self))?;
mapping.insert(*i, outlets[0]);
}
}
// maintaining order of i/o interface
target.inputs = source.input_outlets()?.iter().map(|i| mapping[i]).collect();
target.outputs = source.output_outlets()?.iter().map(|o| mapping[o]).collect();
target.symbol_table = source.symbol_table.clone();
target.properties = source.properties.clone();
Ok((target, mapping))
}251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
let prior_target_inputs = target.input_outlets()?.len();
let prior_target_outputs = target.output_outlets()?.len();
let ModelPatch {
model: patch,
incoming: mut mapping,
shunt_outlet_by,
obliterate,
inputs: replaced_inputs,
..
} = self;
let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
let mut model_input_outlets = target.input_outlets()?.to_vec();
for node in patch.nodes {
if <Graph<F, O>>::is_source(&node.op)
&& mapping.contains_key(&OutletId::new(node.id, 0))
{
// this is a tap
continue;
}
let Node { id: patch_node_id, name, inputs, op, outputs } = node;
let n_outputs = outputs.len();
for dup in 0..target.nodes.len() {
if target.node(dup).op().same_as(op.as_ref())
&& inputs.len() == target.node(dup).inputs.len()
&& inputs
.iter()
.zip(target.node(dup).inputs.iter())
.all(|(patch_input, d)| mapping[patch_input] == *d)
{
for ix in 0..n_outputs {
mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
}
continue;
}
}
let facts = outputs.into_iter().map(|of| of.fact).collect();
let added_node_id = target.add_node(name, op, facts)?;
for ix in 0..n_outputs {
mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
}
all_inputs.insert(added_node_id, inputs);
if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
// this is actually an input replacement
model_input_outlets.iter_mut().for_each(|oo| {
if oo.node == replaced_inputs[&patch_node_id] {
oo.node = added_node_id;
}
});
}
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
for (outlet, by) in shunt_outlet_by {
let replace_by = mapping[&by];
let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
for succ in succs {
target.add_edge(replace_by, succ)?;
}
for o in target.outputs.iter_mut() {
if *o == outlet {
*o = replace_by;
}
}
if let Some(label) = target.outlet_labels.remove(&outlet) {
target.set_outlet_label(replace_by, label)?;
}
}
if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
bail!("Duplicate usage of node as output");
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
for (node, inputs) in all_inputs {
for (ix, input) in inputs.into_iter().enumerate() {
target.add_edge(mapping[&input], InletId::new(node, ix))?;
}
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
for node in obliterate {
target.node_mut(node).op = target.create_dummy();
}
debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
target.set_input_outlets(&model_input_outlets)?;
Ok(())
}

source
pub fn with_outlet_label(
self,
outlet: OutletId,
label: String
) -> TractResult<Self>
pub fn with_outlet_label(
self,
outlet: OutletId,
label: String
) -> TractResult<Self>
Set label for an outlet and return self.
sourcepub fn find_outlet_label(&self, label: &str) -> Option<OutletId>
pub fn find_outlet_label(&self, label: &str) -> Option<OutletId>
Find outlet by label.
sourcepub fn eval_order(&self) -> TractResult<Vec<usize>>
pub fn eval_order(&self) -> TractResult<Vec<usize>>
Computes an evaluation order for the graph's inputs and outputs
Examples found in repository?
324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
pub fn full_axis_tracking(model: &TypedModel) -> TractResult<Vec<AxisTracking>> {
let mut axes: Vec<AxisTracking> = vec![];
for node in model.eval_order()? {
for slot in 0..model.node(node).outputs.len() {
let outlet = OutletId::new(node, slot);
let input_fact = model.outlet_fact(outlet)?;
'axis: for axis in 0..input_fact.rank() {
if axes.iter().any(|tracking| tracking.outlets.get(&outlet) == Some(&axis)) {
continue 'axis;
}
axes.push(AxisTracking::for_outlet_and_axis(model, outlet, axis)?);
}
}
}
Ok(axes)
}

More examples
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
fn full_pass(
&mut self,
session: &mut OptimizerSession,
new: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
for (ix, &id) in new.eval_order()?.iter().enumerate().skip(self.2) {
let node = &new.nodes()[id];
let patch = (self.1)(node.op.as_ref(), session, new, node)
.with_context(|| format!("{:?} node {}", self, node))?;
if let Some(mut p) = patch {
p.push_context(format!("{:?} {}", self, node));
self.2 = ix + p.dont_apply_twice.is_some() as usize;
return Ok(Some(p));
}
}
Ok(None)
}
fn declutter_body_axes(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let mut suggestions = vec![];
for n in self.body.eval_order()? {
let node = self.body.node(n);
for suggestion in node.op.suggested_axis_changes()? {
let outlet = suggestion.0.as_outlet(node);
suggestions.push(AxisChange { outlet, op: suggestion.1 })
}
}
for suggestion in suggestions.into_iter() {
if let Some(op) =
self.try_body_axes_change(suggestion, true)?.and_then(|c| c.substitute_op)
{
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&node.inputs,
op,
)?));
}
}
Ok(None)
}
fn next(
&mut self,
_session: &mut OptimizerSession,
model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
let mut interfaces = model.output_outlets()?.to_vec();
interfaces.extend(model.input_outlets()?.iter());
for n in model.eval_order()? {
for suggestion in model.node(n).op.suggested_axis_changes()? {
if self.0.insert((n, suggestion.clone())) {
let outlet = suggestion.0.as_outlet(model.node(n));
let change = AxisChange { outlet, op: suggestion.1.clone() };
if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[])
.with_context(|| {
format!("Making patch for {:?} from {}", change, model.node(n))
})?
{
return Ok(Some(patch));
}
}
}
}
Ok(None)
}
fn next(&mut self, _session: &mut OptimizerSession, model: &TypedModel) -> TractResult<Option<TypedModelPatch>> {
let mut patch = TypedModelPatch::default();
for node in model.eval_order()? {
for output in &model.node(node).outputs {
for (a, b) in output.successors.iter().tuple_combinations() {
if patch.obliterate.contains(&b.node) {
continue;
}
let a = model.node(a.node);
let b = model.node(b.node);
if a.same_as(b) {
for slot in 0..b.outputs.len() {
let tap = patch.tap_model(model, OutletId::new(a.id, slot))?;
patch.shunt_outside(model, OutletId::new(b.id, slot), tap)?;
patch.obliterate(b.id)?;
}
}
}
}
}
Ok(Some(patch).filter(|p| !p.is_empty()))
}
pub fn check_edges(&self) -> TractResult<()> {
for node_id in self.eval_order()? {
let node = &self.nodes[node_id];
for (ix, input) in node.inputs.iter().enumerate() {
let prec = &self.nodes[input.node];
if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
bail!(
"Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
node.id,
ix,
prec
)
}
}
for (ix, output) in node.outputs.iter().enumerate() {
for succ in &output.successors {
if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
bail!(
"Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
node.id,
ix,
succ
)
}
}
}
}
Ok(())
}
/// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
crate::plan::SimplePlan::new(self)
}
pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
let node = &self.nodes()[id];
if node.inputs.len() != 1 {
return Ok(None);
}
let prec = &self.nodes()[node.inputs[0].node];
if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
return Ok(None);
}
Ok(Some(prec))
}
pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
let mut node = self.node(id);
for _ in 0..count {
if let Some(next) = self.single_prec(node.id)? {
node = next
} else {
return Ok(None);
}
}
Ok(Some(node))
}
pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
let mut node = self.node(id);
for _ in 0..count {
if let Some(next) = self.single_succ(node.id)? {
node = next
} else {
return Ok(None);
}
}
Ok(Some(node))
}
pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
let node = &self.nodes()[id];
if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
return Ok(None);
}
let succ = node.outputs[0].successors[0];
let succ = &self.nodes()[succ.node];
if succ.inputs.len() != 1 {
return Ok(None);
}
Ok(Some(succ))
}
pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
&self.nodes[outlet.node].outputs[outlet.slot].successors
}
}
impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
O: fmt::Debug
+ fmt::Display
+ From<crate::ops::konst::Const>
+ AsRef<dyn Op>
+ AsMut<dyn Op>
+ Clone
+ Hash
+ 'static,
{
pub fn add_const(
&mut self,
name: impl Into<String>,
v: impl IntoArcTensor,
) -> TractResult<OutletId> {
let v = v.into_arc_tensor();
let fact = F::from(v.clone());
let name = name.into();
self.add_node(name, crate::ops::konst::Const::new(v), tvec!(fact)).map(|id| id.into())
}
}
impl<F, O> fmt::Display for Graph<F, O>
where
F: Fact + Hash + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
for i in 0..self.nodes.len() {
let input_1 = self.nodes[i]
.inputs
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let input_2 = self.nodes[i]
.inputs
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_1 = self
.outlet_successors(OutletId::new(i, 0))
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_2 = self
.outlet_successors(OutletId::new(i, 0))
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
writeln!(
fmt,
"{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
i,
input_1,
input_2,
output_1,
output_2,
self.nodes[i].op().name(),
self.nodes[i].name,
self.node_input_facts(i).unwrap(),
self.node_output_facts(i).unwrap(),
)?;
if self.nodes[i].inputs.len() > 2 {
writeln!(
fmt,
" | * inputs: {}",
self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
)?;
}
if self.nodes[i].outputs.len() > 1
|| self.outlet_successors((i, 0).into()).len() > 2
|| (self.outlet_label(i.into()).is_some()
&& self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
{
for o in 0..self.nodes[i].outputs.len() {
if self.outlet_successors((i, o).into()).len() > 0 {
writeln!(
fmt,
" | * output #{}: {} {}",
o,
self.outlet_label((i, o).into()).unwrap_or(""),
self.outlet_successors((i, o).into())
.iter()
.map(|s| format!("{:?}", s))
.join(", "),
)?;
}
}
}
}
writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
Ok(())
}
}
impl<F, O> Graph<F, O>
where
F: Fact + Clone + 'static + std::hash::Hash + for<'a> std::convert::From<&'a F>,
O: std::fmt::Display
+ std::fmt::Debug
+ Clone
+ AsRef<dyn Op>
+ AsMut<dyn Op>
+ Clone
+ 'static
+ std::hash::Hash
+ for<'a> std::convert::From<&'a O>,
Graph<F, O>: SpecialOps<F, O>,
{
#[cfg(debug_assertions)]
pub fn check_compact(&self) -> TractResult<()> {
let order = self.eval_order()?;
let useless_sources = self
.input_outlets()?
.iter()
.filter(|io| {
self.outlet_successors(**io).len() == 0
&& !self.output_outlets().unwrap().contains(io)
})
.count();
if order.len() + useless_sources != self.nodes.len() {
bail!(
"Eval order is {} long, nodes are {}, including {} unused sources",
order.len(),
self.nodes.len(),
useless_sources
);
}
if (0..order.len()).any(|ix| order[ix] != ix) {
bail!("eval order is not trivial");
}
let mut seen = std::collections::HashSet::new();
for (ix, n) in self.nodes.iter().enumerate() {
if ix != n.id {
bail!("Invalid node id: position is {}, node is {}", ix, n);
}
if seen.contains(&n.name) {
eprintln!("{}", self);
bail!("duplicate name {}", n.name);
}
seen.insert(&n.name);
}
Ok(())
}

source
pub fn check_edges(&self) -> TractResult<()>
pub fn check_edges(&self) -> TractResult<()>
Performs a sanity check on network connections.
Examples found in repository?
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
pub fn check_consistency(&self) -> TractResult<()> {
self.check_edges()?;
for node_id in &self.eval_order()? {
let input_facts = self.node_input_facts(*node_id)?;
let node = &self.nodes[*node_id];
if node.id != *node_id {
bail!("Node at position {} has id {}", node_id, node.id);
}
let output_facts = node.op.output_facts(&input_facts)?;
if node.outputs.len() != output_facts.len() {
bail!(
"Inconsistent model, node output count mismatch. Op says {}, node says {}. {}",
output_facts.len(),
node.outputs.len(),
node
);
}
if node
.outputs
.iter()
.map(|o| &o.fact)
.zip(output_facts.iter())
.any(|(a, b)| a.datum_type != b.datum_type || a.shape != b.shape)
{
bail!(
"Inconsistent model, output types mismatch. Op says: {:?}, node says: {:?}. {} with inputs {:?}. {}",
output_facts, node.outputs.iter().map(|o| &o.fact).collect::<Vec<_>>(), node, input_facts, node)
}
}
for node in &self.nodes {
for (ix, output) in node.outputs.iter().enumerate() {
output.fact.consistent().with_context(|| {
format!("Inconsistent fact {:?}: {:?}", OutletId::new(node.id, ix), output.fact)
})?
}
}
Ok(())
}

source
pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>>
pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>>
Converts the model into a RunnableModel which fixes the inputs and outputs and allows passing data through the model.
Examples found in repository?
More examples
779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let mut model = TypedModel::default();
let mut wires: TVec<OutletId> = inputs
.iter()
.enumerate()
.map(|(ix, v)| {
model.add_source(format!("source.{}", ix), v.datum_type().fact(v.shape()))
})
.collect::<TractResult<_>>()?;
let new_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
let wire = unsafe {
if self.q_params.is_some() {
let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
op_ref.wire_as_quant_im2col(
&mut model,
"im2col-adhoc",
inputs[0].datum_type(),
&wires,
)?
} else {
self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
}
};
model.set_output_outlets(&[wire])?;
model.into_runnable()?.run(inputs)
}
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
ensure!(
inputs[0].rank() == inputs[1].rank(),
"Rank mismatch {:?} vs {:?}",
inputs[0],
inputs[1]
);
let mut model = TypedModel::default();
let a = model.add_const("source_a", inputs[0].clone().into_arc_tensor())?;
let b = model.add_const("source_b", inputs[1].clone().into_arc_tensor())?;
let bias = model.add_const("source_bias", inputs[2].clone().into_arc_tensor())?;
let mut input_outlets = tvec![a, b, bias];
for (i, t) in inputs.iter().enumerate().skip(3) {
input_outlets
.push(model.add_const(format!("source_{}", i), t.clone().into_arc_tensor())?)
}
let mut params = self.params.as_outlet_ids(
&mut model,
"qmatmul_unary",
&input_outlets,
inputs[0].datum_type(),
inputs[1].datum_type(),
self.output_type,
)?;
let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;
let new_op = MatMul { axes: self.axes };
let result = model.wire_node("adhoc.matmul", new_op, &[a, b])?[0];
let result = wire_matmul_quant(
&mut model,
"adhoc",
a,
b,
Some(bias),
self.axes,
result,
self.output_type,
¶ms,
)?;
model.set_output_outlets(&[result])?;
model.into_runnable()?.run(tvec![])
}
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
ensure!(inputs[0].rank() == self.a.rank(), "Rank mismatch {:?} vs {:?}", inputs[0], self.a);
let mut model = TypedModel::default();
let t_a = self.a.offset_u8_as_i8();
let a = model.add_const("source_a", self.a.clone())?;
let b = model.add_const("source_b", inputs[0].clone().into_arc_tensor())?;
let bias = if let Some(bias) = self.bias.clone() {
Some(model.add_const("source_bias", bias)?)
} else {
None
};
let mut input_outlets = tvec![a];
for (i, t) in inputs.iter().enumerate().skip(1) {
input_outlets
.push(model.add_const(format!("source_{}", i), t.clone().into_arc_tensor())?)
}
let mut params = self.params.as_outlet_ids(
&mut model,
"qmatmul_unary",
&input_outlets,
self.a.datum_type(),
inputs[0].datum_type(),
self.output_type,
)?;
let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;
let new_op = MatMulUnary { a: t_a, axes: self.axes };
let result = model.wire_node("adhoc.matmul", new_op, &[b])?[0];
let result = wire_matmul_quant(
&mut model,
"adhoc",
a,
b,
bias,
self.axes,
result,
self.output_type,
¶ms,
)?;
model.set_output_outlets(&[result])?;
model.into_runnable()?.run(tvec![])
}

source
pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>>
pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>>
Examples found in repository?
More examples
434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455
fn declutter_recip(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
use super::element_wise::*;
if let Some(prec) = model.single_prec(node.id)? {
if let Some(ew) = prec.op_as::<ElementWiseOp>() {
let repl = if ew.0.is::<Sqrt>() {
Some(rsqrt())
} else if ew.0.is::<Rsqrt>() {
Some(sqrt())
} else {
None
};
if let Some(repl) = repl {
let mut patch = TypedModelPatch::default();
let mut wire = patch.tap_model(model, prec.inputs[0])?;
wire = patch.wire_node(&node.name, repl, &[wire])?[0];
patch.shunt_outside(model, node.id.into(), wire)?;
return Ok(Some(patch));
}
}
}
Ok(None)
}
fn pull_downsample_up(
model: &TypedModel,
down_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
model.check_consistency()?;
let down_op = down_node.op_as::<Downsample>().unwrap();
if let Some(prec) = model.single_prec(down_node.id)? {
let (input_facts, output_facts) = model.node_facts(prec.id)?;
let invariants = prec.op.invariants(&input_facts, &output_facts)?;
debug!("Consider pull {:?} over {:?} (invariants: {:?})", down_op, prec, invariants);
if let Some(slice_op) = prec.op_as::<ops::array::Slice>() {
if let Some(p) = array::pull_downsample_over_slice(model, prec, slice_op, down_node, down_op)? {
return Ok(Some(p))
}
} else if let Some(other_op) = prec.op_as::<AxisOp>() {
return array::pull_downsample_over_axis_op(model, prec, other_op, down_node, down_op);
} else if let Some(conv_op) = prec.op_as::<ops::cnn::conv::ConvUnary>() {
return conv::fuse_downsample_into_conv(model, prec, conv_op, down_node, down_op);
} else if let Some(other_op) = prec.op_as::<ops::scan::Scan>() {
return scan::pull_downsample_over_scan(model, prec, other_op, down_node, down_op);
}
if let Some(above_axis) = invariants.unary_track_axis_up(down_op.axis, false) {
let mut patch = TypedModelPatch::default();
let mut inputs = vec![];
for (ix, &oo) in prec.inputs.iter().enumerate() {
let source = patch.tap_model(model, oo)?;
let mut op = down_op.clone();
op.axis = above_axis;
let ds = patch.wire_node(
format!("{}.{}-{}", down_node.name, prec.name, ix),
op,
[source].as_ref(),
)?;
inputs.push(ds[0]);
}
let other = patch.wire_node(&prec.name, prec.op.clone(), &inputs)?;
patch.shunt_outside(model, OutletId::new(down_node.id, 0), other[0])?;
return Ok(Some(patch));
}
}
Ok(None)
}

pub fn single_prec_at(
&self,
id: usize,
count: usize
) -> TractResult<Option<&Node<F, O>>>
pub fn single_succ_at(
&self,
id: usize,
count: usize
) -> TractResult<Option<&Node<F, O>>>
sourcepub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>>
pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>>
Examples found in repository?
More examples
164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
pub fn fuse_with_next<IO: Into<O>>(
patched_model: &Graph<F, O>,
node: &Node<F, O>,
new_op: IO,
) -> TractResult<ModelPatch<F, O>> {
let mut patch = ModelPatch::default();
let succ = if let Some(succ) = patched_model.single_succ(node.id)? {
succ
} else {
bail!("Non single successor fuse attempt")
};
let new_op = new_op.into();
let by = patch.add_node(&*node.name, new_op, tvec!(succ.outputs[0].fact.clone()))?;
for (ix, i) in node.inputs.iter().enumerate() {
let o = patch.tap_model(patched_model, *i)?;
patch.add_edge(o, InletId::new(by, ix))?;
}
for ix in 0..node.outputs.len() {
patch.shunt_outside(
patched_model,
OutletId::new(succ.id, ix),
OutletId::new(by, ix),
)?;
}
Ok(patch)
}
fn declutter(
&self,
model: &TypedModel,
dequant: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let mut current = dequant;
let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
while let Some(quant) = model.single_succ(current.id)? {
let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
} else {
op.0.downcast_ref::<QuantizeLinearI8>()
.map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
}
} else {
None
};
if let Some((scale, zero_point, dt)) = q_params {
// first, try Op::quantize() on all ops in the chain
let mut patch = TypedModelPatch::default();
let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
let mut next = model.single_succ(dequant.id)?.unwrap();
loop {
if let Some(op) = next
.op
.quantize(model, dequant, dt, scale, zero_point)
.with_context(|| format!("Quantizing {}", next))?
{
wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
} else {
break;
}
if next.id == current.id {
patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
return Ok(Some(patch));
} else {
next = model.single_succ(next.id)?.unwrap();
}
}
// or else make a lookup table
if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
let mut adhoc_model = TypedModel::default();
let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
let mut next = model.single_succ(dequant.id)?.unwrap();
let mut name = None;
// plug in dequant
wire = adhoc_model.wire_node(
&*dequant.name,
dequant.op.clone(),
[wire].as_ref(),
)?[0];
while next.id != quant.id {
name.get_or_insert(&*next.name);
wire =
adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
[0];
next = model.single_succ(next.id)?.unwrap();
}
// plug in quant
wire =
adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
adhoc_model.set_output_outlets(&[wire])?;
let input = (0u8..=255).collect::<Vec<u8>>();
let input = match dt {
DatumType::I8 => unsafe {
tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
},
DatumType::U8 => tensor1(&input),
_ => unreachable!(),
};
let output =
SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
let table: &[u8] = match dt {
DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
DatumType::U8 => output.as_slice::<u8>()?,
_ => unreachable!(),
};
let op = lookup_table((tract_linalg::ops().lut_u8)(table));
let mut patch = TypedModelPatch::default();
let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
return Ok(Some(patch));
}
}
let (input_facts, output_facts) = model.node_facts(quant.id)?;
let invariants = quant
.op
.invariants(&input_facts, &output_facts)
.with_context(|| format!("Querying invariants for {}", quant))?;
if invariants.element_wise() {
current = quant;
} else {
break;
}
}
Ok(None)
}

source
pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId]
pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId]
Examples found in repository?
579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
for i in 0..self.nodes.len() {
let input_1 = self.nodes[i]
.inputs
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let input_2 = self.nodes[i]
.inputs
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_1 = self
.outlet_successors(OutletId::new(i, 0))
.get(0)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
let output_2 = self
.outlet_successors(OutletId::new(i, 0))
.get(1)
.map(|o| format!("{:?}", o))
.unwrap_or_else(|| "".to_string());
writeln!(
fmt,
"{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
i,
input_1,
input_2,
output_1,
output_2,
self.nodes[i].op().name(),
self.nodes[i].name,
self.node_input_facts(i).unwrap(),
self.node_output_facts(i).unwrap(),
)?;
if self.nodes[i].inputs.len() > 2 {
writeln!(
fmt,
" | * inputs: {}",
self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
)?;
}
if self.nodes[i].outputs.len() > 1
|| self.outlet_successors((i, 0).into()).len() > 2
|| (self.outlet_label(i.into()).is_some()
&& self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
{
for o in 0..self.nodes[i].outputs.len() {
if self.outlet_successors((i, o).into()).len() > 0 {
writeln!(
fmt,
" | * output #{}: {} {}",
o,
self.outlet_label((i, o).into()).unwrap_or(""),
self.outlet_successors((i, o).into())
.iter()
.map(|s| format!("{:?}", s))
.join(", "),
)?;
}
}
}
}
writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
Ok(())
}
}
impl<F, O> Graph<F, O>
where
F: Fact + Clone + 'static + std::hash::Hash + for<'a> std::convert::From<&'a F>,
O: std::fmt::Display
+ std::fmt::Debug
+ Clone
+ AsRef<dyn Op>
+ AsMut<dyn Op>
+ Clone
+ 'static
+ std::hash::Hash
+ for<'a> std::convert::From<&'a O>,
Graph<F, O>: SpecialOps<F, O>,
{
#[cfg(debug_assertions)]
pub fn check_compact(&self) -> TractResult<()> {
let order = self.eval_order()?;
let useless_sources = self
.input_outlets()?
.iter()
.filter(|io| {
self.outlet_successors(**io).len() == 0
&& !self.output_outlets().unwrap().contains(io)
})
.count();
if order.len() + useless_sources != self.nodes.len() {
bail!(
"Eval order is {} long, nodes are {}, including {} unused sources",
order.len(),
self.nodes.len(),
useless_sources
);
}
if (0..order.len()).any(|ix| order[ix] != ix) {
bail!("eval order is not trivial");
}
let mut seen = std::collections::HashSet::new();
for (ix, n) in self.nodes.iter().enumerate() {
if ix != n.id {
bail!("Invalid node id: position is {}, node is {}", ix, n);
}
if seen.contains(&n.name) {
eprintln!("{}", self);
bail!("duplicate name {}", n.name);
}
seen.insert(&n.name);
}
Ok(())
}

More examples
642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780
pub fn change_axes(
model: &TypedModel,
change: &AxisChange,
locked: &[OutletId],
bounds: &[TVec<OutletId>],
) -> TractResult<Option<(TypedModelPatch, TVec<(InOut, AxisOp)>)>> {
trace!("Considering change {:?}", change);
let mut todo_changes = vec![(change.clone(), None)];
let mut changed_wires = HashMap::new();
changed_wires.insert(change.outlet, change.op.clone());
let mut changed_ops: HashMap<usize, Box<dyn TypedOp>> = HashMap::new();
while let Some((c, emitter)) = todo_changes.pop() {
let outlets = if let Some(group) = bounds.iter().find(|b| b.contains(&c.outlet)) {
group.clone()
} else {
tvec![c.outlet]
};
for outlet in outlets {
if locked.contains(&outlet) {
trace!(" Change {:?} blocked by locked interface {:?}", change, outlet);
return Ok(None);
}
let mut nodes = vec![(outlet.node, InOut::Out(outlet.slot))];
for inlet in model.outlet_successors(outlet) {
nodes.push((inlet.node, InOut::In(inlet.slot)));
}
for (node_id, io) in nodes {
if Some(node_id) == emitter {
continue;
}
let node = model.node(node_id);
let more = node
.op
.change_axes(model, node, io, &c.op)
.with_context(|| format!("Propagating {:?} to node {}", change, node))?;
if more.is_none() {
trace!(" Propagation of {:?} blocked by {}", change, node);
return Ok(None);
}
let AxisChangeConsequence { substitute_op, wire_changes } = more.unwrap();
trace!(" Change {:?} enters {} from {:?}", c.op, node, io);
trace!(" propagates as {:?}", wire_changes);
if let Some(op) = substitute_op {
trace!(" replace op by {:?}", op);
changed_ops.insert(node.id, op);
}
for (wire, op) in wire_changes.into_iter() {
let outlet = wire.as_outlet(node);
match changed_wires.entry(outlet) {
Entry::Vacant(entry) => {
trace!(" {:?} {:?} change on {:?} is new", wire, op, outlet);
entry.insert(op.clone());
todo_changes.push((AxisChange { outlet, op }, Some(node_id)));
}
Entry::Occupied(previous) => {
if *previous.get() == op {
trace!(
" {:?} {:?} change on {:?} already done",
wire,
op,
outlet
);
} else {
trace!(
" {:?} {:?} change on {:?} conflicting with {:?}. Blocked.",
wire,
op,
outlet,
previous
);
return Ok(None);
}
}
}
}
}
}
}
trace!("Translating {:?} to patch", change);
let mut patch = TypedModelPatch::new(format!("{:?}", change));
let mut replaced_wires: HashMap<OutletId, OutletId> = HashMap::default();
let nodes_to_replace = changed_wires
.keys()
.map(|o| o.node)
.chain(changed_ops.keys().copied())
.collect::<std::collections::HashSet<usize>>();
for node_id in model.eval_order()? {
let node = model.node(node_id);
if nodes_to_replace.contains(&node_id) {
let mut inputs = tvec!();
for orig in &node.inputs {
let tgt = replaced_wires
.entry(*orig)
.or_insert_with(|| patch.tap_model(model, *orig).unwrap());
inputs.push(*tgt);
}
let op: Box<dyn TypedOp> =
changed_ops.get(&node_id).cloned().unwrap_or_else(|| node.op.clone());
let new_wires = patch.wire_node(&node.name, op, &inputs)?;
if new_wires.len() == 1
&& patch.node(new_wires[0].node).op_is::<crate::ops::source::TypedSource>()
{
patch.inputs.insert(new_wires[0].node, node_id);
}
for (ix, w) in new_wires.iter().enumerate() {
replaced_wires.insert((node_id, ix).into(), *w);
}
} else {
for orig in &node.inputs {
if let Some(replacement) = replaced_wires.get(orig) {
patch.shunt_outside(model, *orig, *replacement)?;
}
}
}
}
for output in model.output_outlets()? {
if let Some(replacement) = replaced_wires.get(output) {
unsafe {
patch.shunt_outside_unchecked(*output, *replacement)?;
}
}
}
let mut interface_change = tvec!();
for (ix, input) in model.input_outlets()?.iter().enumerate() {
if let Some(change) = changed_wires.get(input) {
interface_change.push((InOut::In(ix), change.clone()));
}
}
for (ix, output) in model.output_outlets()?.iter().enumerate() {
if let Some(change) = changed_wires.get(output) {
interface_change.push((InOut::Out(ix), change.clone()));
}
}
debug_assert!(
patch.model.nodes.iter().map(|n| &n.name).collect::<std::collections::HashSet<_>>().len()
== patch.model.nodes.len()
);
Ok(Some((patch, interface_change)))
}source§impl<F, O> Graph<F, O>where
F: Fact + Clone + 'static + From<Arc<Tensor>> + Hash,
O: Debug + Display + From<Const> + AsRef<dyn Op> + AsMut<dyn Op> + Clone + Hash + 'static,
impl<F, O> Graph<F, O>where
F: Fact + Clone + 'static + From<Arc<Tensor>> + Hash,
O: Debug + Display + From<Const> + AsRef<dyn Op> + AsMut<dyn Op> + Clone + Hash + 'static,
source · pub fn add_const(
    &mut self,
    name: impl Into<String>,
    v: impl IntoArcTensor
) -> TractResult<OutletId>
Examples found in repository (lines 381–399):
/// Declutter pass: pull constant body outputs out of the node.
///
/// For every output mapping exposing a `last_value_slot`, if the matching
/// body output fact is a known constant (`konst`), that output can never
/// vary: replace the slot by a constant node in a patch and shunt the old
/// outlet to it. Returns the first such patch found, or `Ok(None)`.
fn declutter_pull_constant_outputs(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
    for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
        if let Some(slot) = mapping.last_value_slot {
            if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
                // Body node producing this constant output — only used to
                // build readable patch and const-node names.
                let inner_node = self.body.output_outlets()?[model_output_ix].node;
                let inner_node = self.body.node(inner_node);
                let mut patch = TypedModelPatch::new(format!("Extract const node {}", inner_node));
                let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
                patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
                return Ok(Some(patch));
            }
        }
    }
    Ok(None)
}More examples
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
/// One step of the constant-propagation optimizer pass.
///
/// Walks the model in evaluation order; for every stateless non-`Const`
/// node whose input facts are all known constants, evaluates the node
/// eagerly and replaces each of its outputs by a constant in a patch.
/// Errors rooted in `UndeterminedSymbol` are ignored (the value stays
/// symbolic); any other eval error aborts with context. Returns `None`
/// when the patch ends up empty (nothing was propagated).
fn next(
    &mut self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
    let mut patch = TypedModelPatch::default();
    for n in model.eval_order()? {
        let node = model.node(n);
        if node.op.is_stateless() && !node.op_is::<Const>() {
            // collect() into Option short-circuits: Some only when *all*
            // inputs carry a constant value.
            if let Some(inputs) = model
                .node_input_facts(n)?
                .iter()
                .map(|f| f.konst.clone().map(|t| t.into_tvalue()))
                .collect()
            {
                match node.op.eval(inputs) {
                    Ok(res) => {
                        for (ix, output) in res.into_iter().enumerate() {
                            // First output keeps the node name; extra
                            // outputs get a ".<ix>" suffix.
                            let mut name = node.name.clone();
                            if ix > 0 {
                                name = format!("{}.{}", name, ix);
                            }
                            let wire = patch.add_const(name, output.into_arc_tensor())?;
                            patch.shunt_outside(model, (n, ix).into(), wire)?;
                        }
                    }
                    Err(e) => {
                        if !e.root_cause().is::<UndeterminedSymbol>() {
                            Err(e).with_context(|| {
                                format!("Eager eval {} during optimisation", model.node(n))
                            })?;
                        }
                    }
                }
            }
        }
    }
    Ok(Some(patch).filter(|p| p.nodes.len() > 0))
}34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
/// Evaluate by building a small ad-hoc `TypedModel` and running it.
///
/// The quantized matmul is expressed as: constant sources for `a`, `b`
/// (`inputs[0]`) and the optional bias, the quantization params wired as
/// outlets, u8 operands re-offset as i8 where needed, a plain `MatMulUnary`
/// for the accumulator, then `wire_matmul_quant` for zero-point
/// compensation and requantization to `self.output_type`.
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
    ensure!(inputs[0].rank() == self.a.rank(), "Rank mismatch {:?} vs {:?}", inputs[0], self.a);
    let mut model = TypedModel::default();
    // a with its u8 values re-offset as i8 (consumed by the inner MatMulUnary).
    let t_a = self.a.offset_u8_as_i8();
    let a = model.add_const("source_a", self.a.clone())?;
    let b = model.add_const("source_b", inputs[0].clone().into_arc_tensor())?;
    let bias = if let Some(bias) = self.bias.clone() {
        Some(model.add_const("source_bias", bias)?)
    } else {
        None
    };
    // Remaining runtime inputs (dynamic quantization params) become consts too.
    let mut input_outlets = tvec![a];
    for (i, t) in inputs.iter().enumerate().skip(1) {
        input_outlets
            .push(model.add_const(format!("source_{}", i), t.clone().into_arc_tensor())?)
    }
    let mut params = self.params.as_outlet_ids(
        &mut model,
        "qmatmul_unary",
        &input_outlets,
        self.a.datum_type(),
        inputs[0].datum_type(),
        self.output_type,
    )?;
    // params[0]/params[2] are a0/b0; the offsetting may rewrite them in place.
    let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
    let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;
    let new_op = MatMulUnary { a: t_a, axes: self.axes };
    let result = model.wire_node("adhoc.matmul", new_op, &[b])?[0];
    let result = wire_matmul_quant(
        &mut model,
        "adhoc",
        a,
        b,
        bias,
        self.axes,
        result,
        self.output_type,
        &params,
    )?;
    model.set_output_outlets(&[result])?;
    // Every input was baked in as a constant, so run with no inputs at all.
    model.into_runnable()?.run(tvec![])
}
}
impl TypedOp for QMatMulUnary {
/// Check input arity and rank against `a`, validate the bias shape
/// (scalar, or vector of length m), and produce the output fact:
/// `output_type` with the computed matmul output shape.
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
    // One "b" input plus however many runtime quantization params.
    if inputs.len() != 1 + self.params.input_count() {
        bail!(
            "Inconsistent q matmul unary. expects {} inputs, got {}",
            1 + self.params.input_count(),
            inputs.len()
        );
    }
    if inputs[0].rank() != self.a.rank() {
        bail!("Inconsistent matmul between {:?} and {:?} (rank mismatch)", inputs[0], self.a);
    }
    let (_m, _k, _n, c_shape) = compute_shape(
        &self.a.shape().iter().map(|d| d.to_dim()).collect::<TVec<_>>(),
        &inputs[0].shape,
        self.axes,
    )?;
    #[allow(clippy::comparison_chain)]
    if let Some(bias) = &self.bias {
        if bias.rank() > 1 {
            anyhow::bail!("Bias must be either scalar or vector (rank 0 or 1).");
        } else if bias.rank() == 1 {
            // A vector bias must match the m dimension of the output.
            let expected_len = c_shape[self.axes.c_m].to_usize()?;
            anyhow::ensure!(
                bias.len() == expected_len,
                "got: {:?} expected len: {:?}",
                bias,
                expected_len
            );
        };
    }
    Ok(tvec!(self.output_type.fact(c_shape)))
}
/// Axis-tracking invariants. Falls back to `Invariants::none()` whenever any
/// quantization parameter is non-scalar — the original FIXME below notes the
/// exact reason is not recorded; presumably non-scalar params tie axes
/// together (TODO confirm).
fn invariants(&self, inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<Invariants> {
    /*
    dbg!(inputs);
    dbg!(&self.params);
    */
    // FIXME: why ?
    if self.params.iter().any(|qp| match qp.1 {
        QParamKind::Attr(t) => t.len() > 1,
        QParamKind::FromInput(ix) => !inputs[*ix].shape.volume().is_one(),
        QParamKind::FromQType => false,
    }) {
        Ok(Invariants::none())
    } else {
        let mut invs =
            super::mir_unary::mir_unary_invariants(inputs[0], outputs[0], self.axes)?;
        // The extra param inputs take no part in any tracked axis.
        for axis in &mut invs.axes {
            axis.inputs.extend(std::iter::repeat(None).take(inputs.len() - 1));
        }
        Ok(invs)
    }
}
/// Delegate the axis change to the shared mir_unary logic; on success,
/// substitute a clone of self carrying the updated `a` tensor and axes.
fn change_axes(
    &self,
    model: &TypedModel,
    node: &TypedNode,
    io: InOut,
    change: &AxisOp,
) -> TractResult<Option<AxisChangeConsequence>> {
    if let Some((a, axes, wire_changes)) =
        super::mir_unary::mir_unary_change_axes(model, node, io, change, &self.axes, &self.a)?
    {
        let op = Self { axes, a: a.into_arc_tensor(), ..self.clone() };
        Ok(Some(AxisChangeConsequence { substitute_op: Some(Box::new(op)), wire_changes }))
    } else {
        Ok(None)
    }
}
/// Declutter: if the variable input is a concat along the k axis, split
/// this quantized matmul into one partial matmul per concat chunk (each
/// computed in i32 with neutral scales and zero output point), sum the
/// partial products, then requantize once with the real combined scale
/// and output zero point.
fn declutter(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    use crate::ops::array::TypedConcat;
    if let Some(concat) = model.nodes()[node.inputs[0].node].op().downcast_ref::<TypedConcat>()
    {
        let mut patch = TypedModelPatch::new("split over k-concatenated input");
        let k_axis = self.axes.a_k;
        if concat.axis == self.axes.b_k {
            let concat_node = model.node(node.inputs[0].node);
            // Start offsets of each concat slice along k, in elements.
            let offsets = concat
                .offsets(&model.node_input_facts(concat_node.id)?)?
                .iter()
                .map(|x| x.to_usize())
                .collect::<TractResult<Vec<usize>>>()?;
            let mut wires = vec![];
            // Partial products accumulate unscaled in i32: neutralize the
            // scales and output zero point; they are reapplied after the sum.
            let mut params_for_split = self.params.clone();
            params_for_split.a_scale = tensor0(1.0f32).into();
            params_for_split.b_scale = tensor0(1.0f32).into();
            params_for_split.c_scale = tensor0(1.0f32).into();
            params_for_split.c0 = tensor0(0i32).into();
            let input_outlets = node
                .inputs
                .iter()
                .skip(1)
                .map(|o| patch.tap_model(model, *o))
                .collect::<TractResult<TVec<_>>>()?;
            let params_outlets = self.params.as_outlet_ids(
                &mut patch,
                &node.name,
                &input_outlets,
                self.a.datum_type(),
                model.node_input_facts(node.id)?[0].datum_type,
                self.output_type,
            )?;
            // Real combined scale and c0, applied once on the summed result.
            let scale = combine_scales(
                &mut patch,
                &node.name,
                params_outlets[1],
                params_outlets[3],
                params_outlets[5],
            )?;
            let c0 = params_outlets[4];
            for (ix, input) in concat_node.inputs.iter().enumerate() {
                let wire = patch.tap_model(model, *input)?;
                // Slice a to the matching k range; bias only on the first
                // chunk so it is added exactly once.
                let a = self.a.slice(k_axis, offsets[ix], offsets[ix + 1])?;
                let wire = patch
                    .wire_node(
                        format!("{}.k-{}-{}", node.name, offsets[ix], offsets[ix + 1]),
                        Self {
                            a: a.into_arc_tensor(),
                            output_type: DatumType::I32,
                            bias: self.bias.clone().filter(|_| ix == 0),
                            params: params_for_split.clone(),
                            ..self.clone()
                        },
                        &[wire],
                    )
                    .context("wiring new matmulunary")?[0];
                wires.push(wire)
            }
            // Sum the partial i32 products pairwise.
            let mut wire = wires[0];
            for (ix, w) in wires[1..].iter().enumerate() {
                wire = patch.wire_node(
                    format!("{}.k-add-{}", node.name, ix),
                    crate::ops::binary::TypedBinOp(Box::new(crate::ops::math::Add)),
                    &[wire, *w],
                )?[0];
            }
            wire = requant(&mut patch, &node.name, wire, self.output_type, scale, c0)?;
            patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
            return Ok(Some(patch));
        }
    }
    Ok(None)
}
/// Cost delegated to the shared matmul cost helper, using the static `a`
/// shape and the variable input's shape and datum type.
fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
    cost(self.a.shape(), &inputs[0].shape.to_tvec(), inputs[0].datum_type, self.axes)
}
/// Codegen: first inline any statically-known quantization params (and
/// yield, so the pass runs again on the simplified node); otherwise expand
/// to const-a + u8-to-i8 offsetting + plain `MatMulUnary` + quantized
/// epilogue (`wire_matmul_quant`).
fn codegen(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let mut patch = TypedModelPatch::default();
    let t_a = self.a.offset_u8_as_i8();
    if let Some((inputs, qp)) = self.params.inline_static(model, node)? {
        // All params statically known: rewrite self with them inlined (a0
        // offset to match the i8 re-interpretation of a).
        let mut patch = TypedModelPatch::new("inlining matmul quantized params");
        let inputs: Vec<OutletId> =
            inputs.iter().map(|i| patch.tap_model(model, *i)).collect::<TractResult<_>>()?;
        let op = Self {
            a: t_a,
            params: MatMulQParams { a0: qp.a0.offset_u8_as_i8(&patch, &inputs)?, ..qp },
            ..self.clone()
        };
        let wire = patch.wire_node(&node.name, op, &inputs)?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        return Ok(Some(patch));
    }
    let a = patch.wire_node(
        format!("{}.a_const", &node.name),
        ops::konst::Const(self.a.clone()),
        &[],
    )?[0];
    let b = patch.tap_model(model, node.inputs[0])?;
    let bias = if let Some(bias) = self.bias.clone() {
        Some(patch.add_const(format!("{}.bias_const", &node.name), bias)?)
    } else {
        None
    };
    let mut input_outlets = tvec![a];
    for i in node.inputs.iter().skip(1) {
        input_outlets.push(patch.tap_model(model, *i)?)
    }
    let mut params = self.params.as_outlet_ids(
        &mut patch,
        &node.name,
        &input_outlets,
        self.a.datum_type(),
        model.node_input_facts(node.id)?[0].datum_type,
        self.output_type,
    )?;
    // a0/b0 (params[0]/params[2]) may be rewritten by the offsetting.
    let a = wire_offset_u8_as_i8(&mut patch, &node.name, a, "a", &mut params[0], "a0")?;
    let b = wire_offset_u8_as_i8(&mut patch, &node.name, b, "b", &mut params[2], "b0")?;
    let new_op = MatMulUnary { a: t_a, axes: self.axes };
    let result = patch.wire_node(format!("{}.matmul", &node.name), new_op, &[b])?[0];
    let result = wire_matmul_quant(
        &mut patch,
        &node.name,
        a,
        b,
        bias,
        self.axes,
        result,
        self.output_type,
        &params,
    )?;
    patch.shunt_outside(model, node.id.into(), result)?;
    Ok(Some(patch))
}195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 
694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785
/// Materialize the six quantization parameters as model outlets, in order
/// [a0, a_scale, b0, b_scale, c0, c_scale].
///
/// Params are consumed in (zero-point, scale) pairs, one pair per datum
/// type a/b/c. When the datum type itself carries quantization
/// (`dt.qparams()`), the pair is emitted as constants read from the type —
/// any explicit `Attr` must agree with it. Otherwise each param is either
/// a constant (`Attr`) or an existing input wire (`FromInput`);
/// `FromQType` is an error there since the type has nothing to read from.
pub fn as_outlet_ids(
    &self,
    model: &mut TypedModel,
    node_name: &str,
    inputs_wires: &[OutletId],
    a_dt: DatumType,
    b_dt: DatumType,
    c_dt: DatumType,
) -> TractResult<TVec<OutletId>> {
    let mut params_outlets = tvec!();
    for (mut params, dt) in self.iter().chunks(2).into_iter().zip([a_dt, b_dt, c_dt].iter()) {
        if let Some(qp) = dt.qparams() {
            let (x0_name, x0) = params.next().unwrap();
            let (x_scale_name, x_scale) = params.next().unwrap();
            // Explicit attrs must not contradict the type's own zp/scale.
            ensure!(
                (matches!(x0, QParamKind::FromQType)
                    || x0 == &QParamKind::Attr(rctensor0(qp.zp_scale().0)))
                    && (matches!(x_scale, QParamKind::FromQType)
                        || x_scale == &QParamKind::Attr(rctensor0(qp.zp_scale().1))),
            );
            let (zp, scale) = qp.zp_scale();
            let zp = tensor0(zp);
            let zp = model.add_const(format!("{}.{}", node_name, x0_name), zp)?;
            let scale = tensor0(scale);
            let scale = model.add_const(format!("{}.{}", node_name, x_scale_name), scale)?;
            params_outlets.push(zp);
            params_outlets.push(scale)
        } else {
            for (param_name, param) in params {
                match param {
                    QParamKind::Attr(t) => params_outlets.push(
                        model.add_const(format!("{}.{}", node_name, param_name), t.clone())?,
                    ),
                    QParamKind::FromInput(i) => params_outlets.push(inputs_wires[*i]),
                    QParamKind::FromQType => {
                        bail!("Param {} has no quantization parameters", param_name)
                    }
                }
            }
        }
    }
    Ok(params_outlets)
}
}
/// Quantized matrix multiplication with both operands and the bias taken
/// as runtime inputs (contrast with QMatMulUnary, which holds `a` as a
/// constant attribute).
#[derive(Debug, Clone, new, Hash)]
pub struct QMatMul {
    /// Which axes of a/b/c play the m/k/n roles.
    pub axes: MatMulAxes,
    /// Datum type of the requantized output.
    pub output_type: DatumType,
    /// Zero points and scales for a, b and c.
    pub params: MatMulQParams,
}
impl_dyn_hash!(QMatMul);
// Op boilerplate: human-readable name plus the typed-op downcast helper.
impl Op for QMatMul {
    fn name(&self) -> Cow<str> {
        "QMatMul".into()
    }
    op_as_typed_op!();
}
impl EvalOp for QMatMul {
    fn is_stateless(&self) -> bool {
        true
    }
    /// Evaluate by building an ad-hoc `TypedModel`: constant sources for a
    /// (`inputs[0]`), b (`inputs[1]`), bias (`inputs[2]`) and any extra
    /// runtime quantization params, then u8-to-i8 offsetting, a plain
    /// `MatMul`, the quantized epilogue (`wire_matmul_quant`), and a run
    /// of the resulting model with no inputs.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        ensure!(
            inputs[0].rank() == inputs[1].rank(),
            "Rank mismatch {:?} vs {:?}",
            inputs[0],
            inputs[1]
        );
        let mut model = TypedModel::default();
        let a = model.add_const("source_a", inputs[0].clone().into_arc_tensor())?;
        let b = model.add_const("source_b", inputs[1].clone().into_arc_tensor())?;
        let bias = model.add_const("source_bias", inputs[2].clone().into_arc_tensor())?;
        let mut input_outlets = tvec![a, b, bias];
        for (i, t) in inputs.iter().enumerate().skip(3) {
            input_outlets
                .push(model.add_const(format!("source_{}", i), t.clone().into_arc_tensor())?)
        }
        let mut params = self.params.as_outlet_ids(
            &mut model,
            "qmatmul_unary",
            &input_outlets,
            inputs[0].datum_type(),
            inputs[1].datum_type(),
            self.output_type,
        )?;
        // a0/b0 (params[0]/params[2]) may be rewritten by the offsetting.
        let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
        let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;
        let new_op = MatMul { axes: self.axes };
        let result = model.wire_node("adhoc.matmul", new_op, &[a, b])?[0];
        let result = wire_matmul_quant(
            &mut model,
            "adhoc",
            a,
            b,
            Some(bias),
            self.axes,
            result,
            self.output_type,
            &params,
        )?;
        model.set_output_outlets(&[result])?;
        model.into_runnable()?.run(tvec![])
    }
}
impl TypedOp for QMatMul {
/// Check arity (3 data inputs + runtime params), matching operand ranks,
/// and the bias shape (scalar, or vector of length m), then compute the
/// output fact.
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
    if inputs.len() != 3 + self.params.input_count() {
        bail!(
            "Inconsistent q matmul. expects {} inputs, got {}",
            3 + self.params.input_count(),
            inputs.len()
        );
    }
    if inputs[0].rank() != inputs[1].rank() {
        bail!(
            "Inconsistent matmul between {:?} and {:?} (rank mismatch)",
            inputs[0],
            inputs[1]
        );
    }
    let (_m, _k, _n, c_shape) = compute_shape(&inputs[0].shape, &inputs[1].shape, self.axes)?;
    let bias = &inputs[2];
    #[allow(clippy::comparison_chain)]
    if bias.rank() > 1 {
        anyhow::bail!("Bias must be either scalar or vector (rank 0 or 1).");
    } else if bias.rank() == 1 {
        // A vector bias must match the m dimension of the output.
        let expected_len = &c_shape[self.axes.c_m];
        anyhow::ensure!(
            &bias.shape[0] == expected_len,
            "got: {:?} expected len: {:?}",
            bias,
            expected_len
        );
    };
    Ok(tvec!(self.output_type.fact(c_shape)))
}
/// Declutter: when one of a/b and the bias are constants, turn this binary
/// QMatMul into a QMatMulUnary holding the constant operand. If b is the
/// constant one, the operands are flipped, which swaps the a/b roles in
/// both the quantization params and the axes mapping. `FromInput` indices
/// are shifted by 2 for the two removed inputs, and a uniformly-zero bias
/// is dropped entirely.
fn declutter(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let a_fact = model.outlet_fact(node.inputs[0])?;
    let b_fact = model.outlet_fact(node.inputs[1])?;
    let bias_fact = model.outlet_fact(node.inputs[2])?;
    if bias_fact.konst.is_none() {
        return Ok(None);
    }
    let konst_ix = if a_fact.konst.is_some() {
        0
    } else if b_fact.konst.is_some() {
        1
    } else {
        return Ok(None);
    };
    let flip = konst_ix == 1;
    let konst = model.outlet_fact(node.inputs[konst_ix])?.konst.as_ref().unwrap();
    let bias = model.outlet_fact(node.inputs[2])?.konst.clone().unwrap();
    // Keep every input except the constant operand and the bias.
    let inputs: Vec<_> = node
        .inputs
        .iter()
        .enumerate()
        .filter_map(|(i, out_id)| if i == konst_ix || i == 2 { None } else { Some(*out_id) })
        .collect();
    let new_params = {
        let mut qp = self.params.clone();
        //compensate for the removed parameter
        for (_, a) in qp.iter_mut() {
            if let QParamKind::FromInput(i) = a {
                *i -= 2
            }
        }
        if flip {
            MatMulQParams {
                a0: qp.b0,
                a_scale: qp.b_scale,
                b0: qp.a0,
                b_scale: qp.a_scale,
                ..qp
            }
        } else {
            qp
        }
    };
    // Flipping operands also swaps the m/n roles in the axes mapping.
    let axes = if flip {
        MatMulAxes {
            a_m: self.axes.b_n,
            a_k: self.axes.b_k,
            b_n: self.axes.a_m,
            b_k: self.axes.a_k,
            c_m: self.axes.c_n,
            c_n: self.axes.c_m,
        }
    } else {
        self.axes
    };
    TypedModelPatch::replace_single_op(
        model,
        node,
        &inputs,
        QMatMulUnary::new(
            konst.clone(),
            // if bias is uniformly zero, it can be discarded
            Some(bias).filter(|b| {
                b.as_uniform()
                    .map(|b| b.cast_to_scalar::<f32>().unwrap() != 0.0)
                    .unwrap_or(true)
            }),
            axes,
            self.output_type,
            new_params,
        ),
    )
    .map(Some)
}
/// Cost delegated to the shared matmul cost helper on both input shapes.
fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
    cost(
        &inputs[0].shape.to_tvec(),
        &inputs[1].shape.to_tvec(),
        inputs[0].datum_type,
        self.axes,
    )
}
/// Codegen: inline statically-known quantization params first (one rewrite
/// per pass), otherwise expand to u8-to-i8 offsetting + plain `MatMul` +
/// quantized epilogue (`wire_matmul_quant`).
fn codegen(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let mut patch = TypedModelPatch::default();
    if let Some((inputs, qp)) = self.params.inline_static(model, node)? {
        let mut patch = TypedModelPatch::new("inlining matmul quantized params");
        let inputs: Vec<OutletId> =
            inputs.iter().map(|i| patch.tap_model(model, *i)).collect::<TractResult<_>>()?;
        let op = Self { params: qp, ..self.clone() };
        let wire = patch.wire_node(&node.name, op, &inputs)?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        return Ok(Some(patch));
    }
    let a = patch.tap_model(model, node.inputs[0])?;
    let b = patch.tap_model(model, node.inputs[1])?;
    let bias = patch.tap_model(model, node.inputs[2])?;
    let mut input_outlets = tvec![a, b, bias];
    for i in node.inputs.iter().skip(3) {
        input_outlets.push(patch.tap_model(model, *i)?)
    }
    let mut params = self.params.as_outlet_ids(
        &mut patch,
        &node.name,
        &input_outlets,
        model.node_input_facts(node.id)?[0].datum_type,
        model.node_input_facts(node.id)?[1].datum_type,
        self.output_type,
    )?;
    // a0/b0 (params[0]/params[2]) may be rewritten by the offsetting.
    let a = wire_offset_u8_as_i8(&mut patch, &node.name, a, "a", &mut params[0], "a0")?;
    let b = wire_offset_u8_as_i8(&mut patch, &node.name, b, "b", &mut params[2], "b0")?;
    let new_op = MatMul { axes: self.axes };
    let result = patch.wire_node(format!("{}.matmul", &node.name), new_op, &[a, b])?[0];
    let result = wire_matmul_quant(
        &mut patch,
        &node.name,
        a,
        b,
        Some(bias),
        self.axes,
        result,
        self.output_type,
        &params,
    )?;
    patch.shunt_outside(model, node.id.into(), result)?;
    Ok(Some(patch))
}
as_op!();
}
/// Wires the offsetting of a matrix and zero point node.
///
/// Only wires nodes of u8 type and leaves nodes of different type untouched.
/// When the matrix is (unquantized-)u8, its zero point is adjusted in place
/// to stay consistent: a u8 zero point goes through the same
/// `offset_u8_as_i8` op, an i32 zero point gets -128 added.
pub(crate) fn wire_offset_u8_as_i8(
    model: &mut TypedModel,
    model_name: &str,
    matrix: OutletId,
    matrix_name: &str,
    zero_point: &mut OutletId,
    zero_point_name: &str,
) -> TractResult<OutletId> {
    let fact = model.outlet_fact(matrix)?;
    if let DatumType::U8 = fact.datum_type.unquantized() {
        match model.outlet_fact(*zero_point)?.datum_type.unquantized() {
            DatumType::U8 => {
                *zero_point = model.wire_node(
                    format!("{}.offset_{}_as_i8", model_name, zero_point_name),
                    ops::quant::offset_u8_as_i8(),
                    &[*zero_point],
                )?[0];
            }
            DatumType::I32 => {
                // i32 zero point: the u8-to-i8 offsetting is a -128 shift.
                let zp_rank = model.outlet_fact(*zero_point)?.rank();
                let cst = model.add_const(
                    format!("{}.offset_{}_as_i8.min", model_name, zero_point_name),
                    tensor0(-128i32).broadcast_into_rank(zp_rank)?.into_arc_tensor(),
                )?;
                *zero_point = model.wire_node(
                    format!("{}.offset_{}_as_i8", model_name, zero_point_name),
                    ops::math::add(),
                    &[*zero_point, cst],
                )?[0];
            }
            _ => (),
        }
        Ok(model.wire_node(
            format!("{}.offset_{}_as_i8", model_name, matrix_name),
            ops::quant::offset_u8_as_i8(),
            &[matrix],
        )?[0])
    } else {
        Ok(matrix)
    }
}
/// Wire the quantized epilogue after an integer matmul accumulator:
/// optional bias addition, zero-point compensation, and requantization to
/// `output_type`.
///
/// `params` order is [a0, a_scale, b0, b_scale, c0, c_scale].
#[allow(clippy::too_many_arguments)]
pub(crate) fn wire_matmul_quant(
    model: &mut TypedModel,
    name: &str,
    a: OutletId,
    b: OutletId,
    bias: Option<OutletId>,
    axes: MatMulAxes,
    mut result: OutletId,
    output_type: DatumType,
    params: &[OutletId],
) -> TractResult<OutletId> {
    let b_fact = model.outlet_fact(b)?.clone();
    // TODO: assumed c_rank == b_rank (== a_rank)
    if let Some(mut bias) = bias {
        // bias is scalar -> ok
        // bias is vec, m is right in C -> broadcast will add left side axes to bias
        // bias is vec, m is not right in C -> we must append in C axes to the right to align them
        let bias_rank = model.outlet_fact(bias)?.rank();
        if bias_rank == 1 && axes.c_m < b_fact.rank() - 1 {
            for i in 0..(b_fact.rank() - axes.c_m - 1) {
                bias = model.wire_node(
                    format!("{}.axis_rank_fix.{}", name, i),
                    AxisOp::Add(bias_rank + i),
                    &[bias],
                )?[0]
            }
        }
        result = wire_with_rank_broadcast(
            &format!("{}.add_bias", &name),
            model,
            ops::math::add(),
            &[result, bias],
        )?[0];
    }
    let k = model.outlet_fact(a)?.shape[axes.a_k].clone();
    // Combined requantization scale: a_scale * b_scale / c_scale.
    let abc_scale = combine_scales(model, name, params[1], params[3], params[5])?;
    // Operand sums over k in i32, needed for zero-point compensation.
    let a_i32 =
        model.wire_node(format!("{}.a_as_i32", name), ops::cast::cast(i32::datum_type()), &[a])?[0];
    let b_i32 =
        model.wire_node(format!("{}.b_as_i32", name), ops::cast::cast(i32::datum_type()), &[b])?[0];
    let sum_a = model.wire_node(
        format!("{}.sum_a", name),
        ops::nn::Reduce::new(tvec!(axes.a_k), ops::nn::Reducer::Sum),
        &[a_i32],
    )?[0];
    let sum_a =
        model.wire_node(format!("{}.sum_a_reduced", name), AxisOp::Rm(axes.a_k), &[sum_a])?[0];
    let sum_b = model.wire_node(
        format!("{}.sum_b", name),
        ops::nn::Reduce::new(tvec!(axes.b_k), ops::nn::Reducer::Sum),
        &[b_i32],
    )?[0];
    let sum_b =
        model.wire_node(format!("{}.sum_b_reduced", name), AxisOp::Rm(axes.b_k), &[sum_b])?[0];
    let result = compensate_zero_points(
        model, name, result, k, params[0], params[2], sum_a, sum_b, axes.c_m, axes.c_n,
    )?;
    requant(model, name, result, output_type, abc_scale, params[4])
}
/// Wire the combined requantization scale `a_scale * b_scale / c_scale`
/// into the model and return the outlet of the resulting node.
pub(crate) fn combine_scales(
    model: &mut TypedModel,
    name: &str,
    a_scale: OutletId,
    b_scale: OutletId,
    c_scale: OutletId,
) -> TractResult<OutletId> {
    // Numerator first: a_scale * b_scale, broadcast across ranks.
    let product = wire_with_rank_broadcast(
        &format!("{}.ab_scale", name),
        model,
        ops::math::mul(),
        &[a_scale, b_scale],
    )?[0];
    // Then divide by the output scale.
    Ok(wire_with_rank_broadcast(
        &format!("{}.abc_scales", name),
        model,
        ops::math::div(),
        &[product, c_scale],
    )?[0])
}
/// Correct the integer accumulator for non-zero operand zero points.
///
/// Wires `result - a0*sum_b - b0*sum_a + a0*b0*k`, where `sum_a`/`sum_b`
/// are the operands' sums over the k axis (rank one lower than the result)
/// and `k` is the length of the contracted dimension. `sum_a` is
/// re-expanded on the n axis and `sum_b` on the m axis so everything
/// broadcasts against the result; the output shape is unchanged.
#[allow(clippy::too_many_arguments)]
pub(crate) fn compensate_zero_points(
    model: &mut TypedModel,
    name: &str,
    result: OutletId,
    k: TDim,
    a0: OutletId,
    b0: OutletId,
    sum_a: OutletId,
    sum_b: OutletId,
    m_axis: usize,
    n_axis: usize,
) -> TractResult<OutletId> {
    let input_shape = model.outlet_fact(result)?.shape.clone();
    let rank = model.outlet_fact(result)?.rank();
    debug_assert_eq!(model.outlet_fact(sum_a)?.rank(), rank - 1);
    debug_assert_eq!(model.outlet_fact(sum_b)?.rank(), rank - 1);
    // make sum_a into from a 1D vector to a vertical matrix, sum_b horizontal
    // switch shapes if c_trans
    let sum_a =
        model.wire_node(format!("{}.reshape_sum_a", name), AxisOp::Add(n_axis), &[sum_a])?[0];
    let sum_b =
        model.wire_node(format!("{}.reshape_sum_b", name), AxisOp::Add(m_axis), &[sum_b])?[0];
    debug_assert_eq!(
        model.outlet_fact(sum_a)?.shape[m_axis],
        model.outlet_fact(result)?.shape[m_axis]
    );
    debug_assert_eq!(
        model.outlet_fact(sum_b)?.shape[n_axis],
        model.outlet_fact(result)?.shape[n_axis]
    );
    // Zero points and k are cast to i32 to match the accumulator type.
    let a0 =
        model.wire_node(format!("{}.cast_a0", name), ops::cast::cast(i32::datum_type()), &[a0])?[0];
    let b0 =
        model.wire_node(format!("{}.cast_b0", name), ops::cast::cast(i32::datum_type()), &[b0])?[0];
    let k = model.add_const(format!("{}.k", name), rctensor0(k))?;
    let k =
        model.wire_node(format!("{}.cast_k", name), ops::cast::cast(i32::datum_type()), &[k])?[0];
    let a0_sum_b = wire_with_rank_broadcast(
        &format!("{}.a0_sum_b", name),
        model,
        ops::math::mul(),
        &[a0, sum_b],
    )?[0];
    let b0_sum_a = wire_with_rank_broadcast(
        &format!("{}.b0_sum_a", name),
        model,
        ops::math::mul(),
        &[b0, sum_a],
    )?[0];
    let a0_k =
        wire_with_rank_broadcast(&format!("{}.a0_k", name), model, ops::math::mul(), &[a0, k])?[0];
    let a0_k_b0 = wire_with_rank_broadcast(
        &format!("{}.a0_k_b0", name),
        model,
        ops::math::mul(),
        &[a0_k, b0],
    )?[0];
    let result = wire_with_rank_broadcast(
        &format!("{}.minus_a0_B", &name),
        model,
        ops::math::sub(),
        &[result, a0_sum_b],
    )?[0];
    let result = wire_with_rank_broadcast(
        &format!("{}.minus_b0_A", &name),
        model,
        ops::math::sub(),
        &[result, b0_sum_a],
    )?[0];
    let result = wire_with_rank_broadcast(
        &format!("{}.plus_a0_k_b0", &name),
        model,
        ops::math::add(),
        &[result, a0_k_b0],
    )?[0];
    debug_assert_eq!(model.outlet_fact(result)?.shape, input_shape);
    Ok(result)
}
/// Requantize an i32 result: apply the combined scale, add the output
/// zero point (cast to i32), then clamp and cast to the target type.
pub(crate) fn requant(
    model: &mut TypedModel,
    name: &str,
    wire: OutletId,
    dt: DatumType,
    scale: OutletId,
    zero_point: OutletId,
) -> TractResult<OutletId> {
    let wire = wire_with_rank_broadcast(
        &format!("{}.scale", name),
        model,
        ops::quant::scale(),
        &[scale, wire],
    )?[0];
    let zero_point = model.wire_node(
        format!("{}.cast_c0", name),
        ops::cast::cast(i32::datum_type()),
        &[zero_point],
    )?[0];
    let wire = wire_with_rank_broadcast(
        &format!("{}.zeropoint", name),
        model,
        ops::math::add(),
        &[wire, zero_point],
    )?[0];
    clamp_and_cast_to(model, name, dt, wire)
}
/// Clamp an i32 wire to the representable range of `dt` (min against the
/// max value, max against the min value) and cast to `dt`. An i32 target
/// passes through unchanged.
pub(crate) fn clamp_and_cast_to(
    model: &mut TypedModel,
    name: &str,
    dt: DatumType,
    wire: OutletId,
) -> TractResult<OutletId> {
    if dt == i32::datum_type() {
        return Ok(wire);
    }
    let rank = model.outlet_fact(wire)?.rank();
    // Bounds come from the *unquantized* storage type, broadcast to rank.
    let inf = dt
        .unquantized()
        .min_value()
        .cast_to_dt(DatumType::I32)?
        .into_owned()
        .broadcast_into_rank(rank)?
        .into_arc_tensor();
    let inf = model.add_const(format!("{}.min.const", name), inf)?;
    let sup = dt
        .unquantized()
        .max_value()
        .cast_to_dt(DatumType::I32)?
        .into_owned()
        .broadcast_into_rank(rank)?
        .into_arc_tensor();
    let sup = model.add_const(format!("{}.max.const", name), sup)?;
    let wire = model.wire_node(format!("{}.min", name), ops::math::min(), &[wire, sup])?;
    let wire = model.wire_node(format!("{}.max", name), ops::math::max(), &[wire[0], inf])?;
    let wire = model.wire_node(format!("{}.cast", name), ops::cast::cast(dt), &wire)?;
    Ok(wire[0])
}230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
/// Declutter rules for Mul: drop the neutral operand (delegated to
/// `declutter_neutral`), rewrite multiplication by a uniform zero as a
/// broadcast of the zero constant, and turn multiplication by an integer
/// power of two into a left shift.
fn declutter_mul(
    _op: &Mul,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(p) = declutter_neutral(model, node, 1, true).context("decluttering neutral")? {
        return Ok(Some(p));
    }
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let var_fact = model.outlet_fact(uniform.var)?;
        if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
            // x * 0: the result is just zero broadcast to the output shape.
            let shapes =
                model.node_input_facts(node.id)?.iter().map(|f| &f.shape).collect::<TVec<_>>();
            let shape: ShapeFact =
                crate::broadcast::multi_broadcast(&shapes).context("Failed to broadcast")?.into();
            return Ok(Some(TypedModelPatch::rewire(
                model,
                &[],
                &[node.id.into()],
                &|patch, _| {
                    let scalar =
                        patch.add_const(format!("{}.zero", node.name), uniform.uni.clone())?;
                    let op = MultiBroadcastTo::new(shape.clone());
                    patch.wire_node(&node.name, op, &[scalar])
                },
            )?));
        }
        let dt = uniform.uni.datum_type();
        let integer = uniform.uni.cast_to_scalar::<i64>()?;
        // Exact integer constant with a single set bit: x * 2^s == x << s.
        if tensor0(integer)
            .cast_to_dt(uniform.uni.datum_type())?
            .close_enough(&uniform.uni, false)
            .is_ok()
            && dt.is_integer()
            && uniform.uni.cast_to_scalar::<i64>()?.count_ones() == 1
        {
            let shift = integer.trailing_zeros();
            return Ok(Some(TypedModelPatch::rewire(
                model,
                &[uniform.var],
                &[node.id.into()],
                &|patch, taps| {
                    let shift = patch.add_const(
                        format!("{}.shift", node.name),
                        tensor0(shift)
                            .cast_to_dt(dt)?
                            .into_owned()
                            .broadcast_into_rank(var_fact.rank())?,
                    )?;
                    patch.wire_node(&node.name, shift_left(), &[taps[0], shift])
                },
            )?));
        }
    }
    Ok(None)
}
/// Declutter rules for Div: drop the neutral divisor (delegated to
/// `declutter_neutral`), turn integer division by a power of two into a
/// right shift, and rewrite float division by a uniform divisor as
/// multiplication by its reciprocal.
fn declutter_div(
    _op: &Div,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(p) = declutter_neutral(model, node, 1, false)? {
        return Ok(Some(p));
    }
    if let &[p, q] = &*model.node_input_facts(node.id)? {
        if let Some(q) = &q.uniform {
            let dt = q.datum_type();
            if let Ok(integer) = q.cast_to_scalar::<i64>() {
                // Exact integer divisor with a single set bit: x / 2^s == x >> s.
                if tensor0(integer).cast_to_dt(dt)?.close_enough(q, false).is_ok()
                    && dt.is_integer()
                    && q.cast_to_scalar::<i64>()?.count_ones() == 1
                {
                    let shift = integer.trailing_zeros();
                    return Ok(Some(TypedModelPatch::rewire(
                        model,
                        &[node.inputs[0]],
                        &[node.id.into()],
                        &|patch, taps| {
                            let shift = patch.add_const(
                                format!("{}.shift", node.name),
                                tensor0(shift)
                                    .cast_to_dt(dt)?
                                    .into_owned()
                                    .broadcast_into_rank(p.rank())?,
                            )?;
                            patch.wire_node(&node.name, shift_right(), &[taps[0], shift])
                        },
                    )?));
                }
            }
            if dt.is_float() {
                // Float: x / q == x * recip(q).
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &node.inputs,
                    &[node.id.into()],
                    &|patch, taps| {
                        let q =
                            patch.wire_node(format!("{}-recip", node.name), recip(), &[taps[1]])?
                                [0];
                        patch.wire_node(&node.name, mul(), &[taps[0], q])
                    },
                )?));
            }
        }
    }
    Ok(None)
}108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 
/// If the kernel is stored as u8, rebuild the op with the kernel shifted into
/// i8 range, adjusting the a0 (kernel zero-point) quantization parameter to
/// match. When a0 is fed by a model input, the corresponding wire in `inputs`
/// is rewritten in place to apply the same offset.
///
/// Returns `Ok(None)` when the kernel is not u8 (nothing to do).
fn kernel_offset_u8_as_i8(
    &self,
    inputs: &mut [OutletId],
    model: &mut TypedModel,
) -> TractResult<Option<Self>> {
    if let DatumType::U8 = self.kernel.datum_type().unquantized() {
        let new_op = Self {
            // Shift the stored kernel values from u8 into i8 range.
            kernel: self.kernel.offset_u8_as_i8(),
            q_params: self
                .q_params
                .as_ref()
                .map(|(dt, qp)| -> TractResult<_> {
                    let a0 = match &qp.a0 {
                        // Static zero-point: offset the stored value directly.
                        QParamKind::Attr(_) | QParamKind::FromQType => {
                            qp.a0.offset_u8_as_i8(model, &[])?
                        }
                        // Zero-point fed by input #i: patch that wire instead.
                        QParamKind::FromInput(i) => {
                            match model.outlet_fact(inputs[*i])?.datum_type.unquantized() {
                                // u8 wire: insert an explicit u8->i8 offset node.
                                DatumType::U8 => {
                                    inputs[*i] = model.wire_node(
                                        format!(
                                            "{}.offset_{}_as_i8",
                                            model.node(inputs[*i].node).name,
                                            "a0"
                                        ),
                                        ops::quant::offset_u8_as_i8(),
                                        &[inputs[*i]],
                                    )?[0];
                                }
                                // i32 wire: add a -128 constant to shift the range.
                                DatumType::I32 => {
                                    let cst = model.add_const(
                                        format!(
                                            "{}.offset_{}_as_i8.cst",
                                            &model.node(inputs[*i].node).name,
                                            "a0"
                                        ),
                                        rctensor0(-128i32),
                                    )?;
                                    inputs[*i] = model.wire_node(
                                        format!(
                                            "{}.offset_{}_as_i8",
                                            model.node(inputs[*i].node).name,
                                            "a0"
                                        ),
                                        ops::math::add(),
                                        &[inputs[*i], cst],
                                    )?[0];
                                }
                                // Other dtypes: wire left untouched.
                                _ => (),
                            }
                            QParamKind::FromInput(*i)
                        }
                    };
                    // Keep every other quantization parameter as-is.
                    Ok((*dt, MatMulQParams { a0, ..qp.clone() }))
                })
                .transpose()?,
            ..self.clone()
        };
        Ok(Some(new_op))
    } else {
        Ok(None)
    }
}
/// Translate the optional bias into per-group fused-op lists (one
/// `BinPerRow` add per group), shaped to align with the matmul output:
/// the group axis is dropped when group == 1, and a leading axis is added
/// when the data format carries a batch (N) dimension.
fn bias_as_non_linear<T>(&self) -> TractResult<ArrayD<Vec<ProtoFusedSpec>>>
where
    T: Datum + Copy,
{
    // One (possibly empty) list of fused ops per convolution group.
    let mut per_group = Array1::from_elem(self.group, vec![]);
    if let Some(b) = &self.bias {
        let casted = b.cast_to::<T>()?;
        let values = casted.as_slice::<T>()?;
        let channels_per_group = self.output_channels() / self.group;
        for (specs, chunk) in per_group.iter_mut().zip(values.chunks(channels_per_group)) {
            specs.push(ProtoFusedSpec::BinPerRow(
                rctensor1(chunk).into(),
                tract_linalg::mmm::BinOp::Add,
            ));
        }
    }
    let mut per_group = per_group.into_dyn();
    if self.group == 1 {
        // Single group: collapse the group axis away.
        per_group.index_axis_inplace(Axis(0), 0);
    }
    if self.pool_spec.data_format.has_n() {
        // Batchful formats get a leading broadcast axis.
        per_group.insert_axis_inplace(Axis(0));
    }
    Ok(per_group)
}
/// Wire this quantized convolution into `model` as an im2col + quantized
/// matrix-multiplication pipeline: offset u8 data to i8 if needed, run
/// im2col, multiply against the packed kernel, compensate both zero points,
/// then requantize to the output type.
///
/// `wires[0]` is the data input (b); remaining wires feed runtime
/// quantization parameters. Returns the outlet of the requantized result.
///
/// # Safety
/// Declared unsafe by signature; no unsafe operation is visible in this
/// body. NOTE(review): presumably inherited from the LIR matmul contract —
/// confirm against `wire_lir_matmatmul` callees.
pub unsafe fn wire_as_quant_im2col(
    &self,
    model: &mut TypedModel,
    name: &str,
    b_dt: DatumType,
    wires: &[OutletId],
) -> TractResult<OutletId> {
    use crate::ops::matmul::mir_quant as qmm;
    let c_dt = self.q_params.as_ref().unwrap().0;
    // Resolve the six quantization params to outlets, in fixed order:
    // a0, a_scale, b0, b_scale, c0, c_scale.
    let params = self.q_params.as_ref().unwrap().1.as_outlet_ids(
        model,
        name,
        wires,
        self.kernel.datum_type(),
        b_dt,
        c_dt,
    )?;
    let a0 = params[0];
    let a_scale = params[1];
    let mut b0 = params[2];
    let b_scale = params[3];
    let c0 = params[4];
    let c_scale = params[5];
    // If b is u8, rewire it (and its zero point) into i8 range.
    let b = wire_offset_u8_as_i8(model, name, wires[0], "b", &mut b0, "b0")?;
    let b_fact = model.outlet_fact(b)?.clone();
    let (_, m, k, n, mmm) = self.compute_geo(&b_fact)?;
    let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;
    // Combined a*b/c scale applied at requantization time.
    let abc_scale = qmm::combine_scales(model, name, a_scale, b_scale, c_scale)?;
    let im2col = model.wire_node(
        format!("{}.im2col", name),
        Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
        &[b, b0],
    )?[0];
    // sum_a: kernel summed over its last (k) axis, used for zero-point
    // compensation; shaped to match the matmul output layout.
    let a = self.kernel_as_group_o_ihw()?.into_tensor();
    let a = a.cast_to_dt(i32::datum_type())?;
    let a = a.to_array_view::<i32>()?;
    let mut sum_a = a.sum_axis(Axis(a.ndim() - 1));
    if self.group == 1 {
        sum_a.index_axis_inplace(Axis(0), 0);
    }
    if self.pool_spec.data_format.has_n() {
        sum_a.insert_axis_inplace(Axis(0));
    }
    let sum_a = model.add_const(format!("{}.sum_a", name), sum_a)?;
    // sum_b: per-column sums of the im2col output, for the other zero point.
    let mut sum_b = model.wire_node(
        format!("{}.sum_b", name),
        super::QSumB { n: n.clone(), r: mmm.b_pack().panel_width(), k },
        &[im2col],
    )?[0];
    if self.group > 1 && self.pool_spec.data_format.c_is_last() {
        // Align sum_b's group axis with the matmul output in channel-last layouts.
        let has_n = self.pool_spec.data_format.has_n() as usize;
        sum_b = model.wire_node(
            format!("{}.transpose_sum_b", name),
            AxisOp::Move(has_n, 1 + has_n),
            &[sum_b],
        )?[0];
    }
    let b_dt = model.outlet_fact(b)?.datum_type;
    let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
    let mut geometry = MatMulGeometry::from(SymbolicMatMulGeometry {
        b_datum_type: b_dt,
        m: m.to_dim(),
        k: k.to_dim(),
        n: n.clone(),
        mmm: mmm.clone(),
    });
    if n.to_usize().is_ok() {
        // n is fully known: specialize the geometry now.
        geometry = geometry.optimize_if(Some(&SymbolValues::default()))?;
    }
    // Accumulate in i32; requantization to c_dt happens below.
    let wire = self.wire_lir_matmatmul(
        model,
        name,
        im2col,
        mmm,
        i32::datum_type(),
        mmm_output_shape.clone().into(),
        m,
        k,
        geometry,
        c_axis,
        h_axis,
    )?;
    let has_n = self.pool_spec.data_format.has_n() as usize;
    let has_group = (self.group > 1) as usize;
    // Locate the m (output channel) and n (spatial) axes in the mmm output.
    let (m_axis, n_axis) = if self.pool_spec.data_format.c_is_last() {
        (1 + has_group + has_n, has_n)
    } else {
        (has_group + has_n, 1 + has_n + has_group)
    };
    let wire = qmm::compensate_zero_points(
        model,
        name,
        wire,
        k.to_dim(),
        a0,
        b0,
        sum_a,
        sum_b,
        m_axis,
        n_axis,
    )?;
    let mut wire = qmm::requant(model, name, wire, c_dt, abc_scale, c0)?;
    if self.group > 1 {
        // Merge (group, channels-per-group) back into a single channel axis.
        wire = model.wire_node(
            format!("{}.reshape_group", name),
            AxisOp::Reshape(
                c_axis - 1,
                mmm_output_shape[c_axis - 1..][..2].iter().map(|d| d.to_dim()).collect(),
                tvec!((m * self.group).to_dim()),
            ),
            &[wire],
        )?[0];
    }
    // Expand the collapsed geometry axis back into the HW dims.
    let wire = Self::wire_geo_reshape(model, name, wire, &output_shape)?;
    Ok(wire)
}
/// Wire this (non-quantized) convolution into `model` as an im2col node
/// followed by a LIR matrix multiplication, then reshape the result back to
/// the convolution output shape.
///
/// # Safety
/// Declared unsafe by signature; no unsafe operation is visible in this
/// body. NOTE(review): presumably inherited from the LIR matmul contract —
/// confirm.
pub unsafe fn wire_as_im2col_pair(
    &self,
    model: &mut TypedModel,
    name: &str,
    mut wire: OutletId,
) -> TractResult<OutletId> {
    let b_fact = model.outlet_fact(wire)?.clone();
    let b_dt = b_fact.datum_type;
    let c_dt = crate::ops::matmul::output_type(b_fact.datum_type);
    let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;
    let (_, m, k, n, mmm) = self.compute_geo(model.outlet_fact(wire)?)?;
    // Non-quantized path: pad with a plain zero of the input dtype.
    let padding = model.add_const(format!("{}.b0", name), Tensor::zero_dt(b_dt, &[])?)?;
    wire = model.wire_node(
        format!("{}.im2col", name),
        Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
        &[wire, padding],
    )?[0];
    let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
    let mut geometry = MatMulGeometry::from(SymbolicMatMulGeometry {
        b_datum_type: b_dt,
        m: m.to_dim(),
        k: k.to_dim(),
        n: n.clone(),
        mmm: mmm.clone(),
    });
    if n.to_usize().is_ok() {
        // n is fully known: specialize the geometry now.
        geometry = geometry.optimize_if(Some(&SymbolValues::default()))?;
    }
    let mut wire = self.wire_lir_matmatmul(
        model,
        name,
        wire,
        mmm,
        c_dt,
        mmm_output_shape.clone().into(),
        m.to_usize().unwrap(),
        k.to_usize().unwrap(),
        geometry,
        c_axis,
        h_axis,
    )?;
    if self.group > 1 {
        // Merge (group, channels-per-group) back into a single channel axis.
        wire = model.wire_node(
            format!("{}.reshape_group", name),
            AxisOp::Reshape(
                c_axis - 1,
                mmm_output_shape[c_axis - 1..][..2].iter().map(|d| d.to_dim()).collect(),
                tvec!((m * self.group).to_dim()),
            ),
            &[wire],
        )?[0];
    }
    // Expand the collapsed geometry axis back into the HW dims.
    let wire = Self::wire_geo_reshape(model, name, wire, &output_shape)?;
    Ok(wire)
}
/// Compute the shape of the matmul output for this convolution, with all
/// spatial dims collapsed into one geometry dim, plus the positions of the
/// channel (c) and geometry (h) axes in that shape.
///
/// When group > 1, the channel dim is split into (group, channels-per-group)
/// and the returned axis indices are shifted accordingly.
fn mmm_output_shape<D: DimLike>(
    &self,
    output_shape: &BaseDataShape<D, TVec<D>>,
) -> TractResult<(TVec<D>, usize, usize)> {
    // Collapse all spatial (HW) dims into a single geometry dim.
    let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
    let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.from_n_c_hw(
        output_shape.n().cloned().unwrap_or_else(|| 1.into()),
        output_shape.c().clone(),
        tvec!(geo_collapsed_out),
    )?;
    let mut mmm_output_shape: TVec<D> = shape.shape.clone();
    let mut c_axis = shape.c_axis();
    let mut h_axis = shape.h_axis();
    // (The original nested a second `if self.group > 1` inside this block —
    // always true here, so it has been flattened away.)
    if self.group > 1 {
        // Split channels into (group, channels-per-group); the inserted group
        // axis shifts every axis at or after c_axis one position right.
        mmm_output_shape[c_axis] = mmm_output_shape[c_axis].clone() / self.group;
        mmm_output_shape.insert(c_axis, self.group.into());
        if h_axis > c_axis {
            h_axis += 1;
        }
        c_axis += 1;
    }
    Ok((mmm_output_shape, c_axis, h_axis))
}
fn wire_geo_reshape<D: DimLike>(
model: &mut TypedModel,
name: &str,
wire: OutletId,
output_shape: &BaseDataShape<D, TVec<D>>,
) -> TractResult<OutletId> {
let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
let wire = model.wire_node(
name,
AxisOp::Reshape(
output_shape.h_axis(),
tvec!(geo_collapsed_out.to_dim()),
output_shape.hw_dims().iter().map(|d| d.to_dim()).collect(),
),
&[wire],
)?;
Ok(wire[0])
}
/// Wire this convolution using a "lazy" (virtual) im2col: instead of
/// materializing the im2col matrix, precompute byte offsets into the input
/// and let the matmul read patches on the fly. Requires a fully concrete
/// input shape. Explicit padding is materialized with a Pad node first so
/// offsets can assume a valid geometry.
///
/// # Safety
/// Declared unsafe by signature; no unsafe operation is visible in this
/// body. NOTE(review): presumably inherited from the LIR matmul contract —
/// confirm.
pub unsafe fn wire_as_lazy_im2col(
    &self,
    model: &mut TypedModel,
    name: &str,
    mut wire: OutletId,
) -> TractResult<OutletId> {
    let mut b_fact = model.outlet_fact(wire)?.clone();
    let (geo, m, k, n, mmm) = self.compute_geo(&b_fact)?;
    // Lazy offsets need concrete dims (unwrap assumes the caller checked).
    let input_shape = b_fact.shape.as_concrete().unwrap().to_vec();
    let mut geo = geo.to_concrete(&input_shape)?.into_owned();
    let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
    let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
    if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
        // Materialize the padding, then recompute geometry as if Valid.
        let mut pads = vec![(0, 0); b_fact.rank()];
        for (ix, ax) in padding.iter().enumerate() {
            pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
        }
        let op = crate::ops::array::Pad {
            mode: crate::ops::array::PadMode::Constant(
                Tensor::zero_scalar_dt(b_fact.datum_type)?.into_arc_tensor(),
            ),
            pads,
        };
        wire = model.wire_node(format!("{}.pad", name), op, &[wire])?[0];
        let valid_pool_spec =
            PoolSpec { padding: ops::cnn::PaddingSpec::Valid, ..self.pool_spec.clone() };
        b_fact = model.outlet_fact(wire)?.clone();
        let concrete_shape = b_fact.shape.as_concrete().unwrap();
        input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
        geo = valid_pool_spec
            .compute_geo(&b_fact.shape)?
            .to_concrete(concrete_shape)?
            .into_owned();
    }
    let c_dt = crate::ops::matmul::output_type(b_fact.datum_type);
    let c_stride = input_shape.c_stride();
    let size_of_b = b_fact.datum_type.size_of() as isize;
    // Byte offsets of each output point's patch center (n axis)...
    let n_bytes_offsets: Vec<isize> =
        geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
    // ...and of each (input channel, kernel position) element (k axis).
    let k_bytes_offsets: Vec<isize> = (0..self.input_channels())
        .flat_map(|ici| {
            geo.patch
                .standard_layout_data_field
                .iter()
                .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
        })
        .collect();
    let virtual_input = super::lazy_im2col::LazyIm2colSpec { n_bytes_offsets, k_bytes_offsets };
    let b_storage = mmm.b_virtual_input(Box::new(virtual_input), k);
    let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;
    // All dims are concrete here, so the geometry is fully concrete too.
    let geometry = MatMulGeometry::Concrete(ConcreteMatMulGeometry {
        m,
        k,
        n: n.to_usize().unwrap(),
        b_storage,
    });
    let wire = self.wire_lir_matmatmul(
        model,
        name,
        wire,
        mmm,
        c_dt,
        mmm_output_shape.into(),
        m.to_usize().unwrap(),
        k,
        geometry,
        c_axis,
        h_axis,
    )?;
    // Expand the collapsed geometry axis back into the HW dims.
    let wire = Self::wire_geo_reshape(model, name, wire, &geo.output_shape)?;
    Ok(wire)
}
#[allow(clippy::type_complexity)]
/// Derive the pooling geometry and matmul problem for this convolution on
/// `input_fact`: returns (geometry, m, k, n, multiplier) where m is output
/// channels per group, k the number of kernel values per output point, and
/// n the (possibly symbolic) count of output spatial points.
fn compute_geo(
    &self,
    input_fact: &TypedFact,
) -> TractResult<(PoolGeometry, usize, usize, TDim, Box<dyn MatMatMul>)> {
    let geo = self.pool_spec.compute_geo(&input_fact.shape)?;
    trace!("output channels: {:?}", self.output_channels());
    // Datum types selecting the multiplier: kernel (a), input (b), output (c).
    let a_dt = self.kernel.datum_type();
    let b_dt = input_fact.datum_type;
    let c_dt = crate::ops::matmul::output_type(b_dt);
    let m = self.output_channels() / self.group;
    let k = self.kernel.len() / self.output_channels();
    let n: TDim = self
        .pool_spec
        .output_shape(&input_fact.shape)?
        .hw_dims()
        .iter()
        .cloned()
        .product();
    let mmm = tract_linalg::ops()
        .mmm(a_dt, b_dt, c_dt, Some(m), Some(k), n.to_usize().ok())
        .with_context(|| format!("No multiplier for {:?}x{:?} to {:?}", a_dt, b_dt, c_dt,))?;
    Ok((geo, m, k, n, mmm))
}
#[allow(clippy::too_many_arguments)]
/// Wire the low-level `LirMatMulUnary` node: pack the kernel for the chosen
/// multiplier, attach the fused bias ops (each terminated by a Store), and
/// connect it to `wire`. Returns the matmul's output outlet.
fn wire_lir_matmatmul(
    &self,
    model: &mut TypedModel,
    name: &str,
    wire: OutletId,
    mmm: Box<dyn MatMatMul>,
    c_datum_type: DatumType,
    mmm_output_shape: ShapeFact,
    m: usize,
    k: usize,
    geometry: MatMulGeometry,
    c_m_axis: usize,
    c_n_axis: usize,
) -> TractResult<OutletId> {
    // Kernel pre-packed in the multiplier's preferred layout.
    let kernels = self.kernel_as_packed_as(&mmm.a_pack(), k, m)?;
    let shape = kernels.shape();
    // Bias becomes per-row fused additions, in the multiplier's accumulator type.
    let mut fused_ops = dispatch_copy!(Self::bias_as_non_linear(mmm.internal_type())(self))?;
    for fo in &mut fused_ops {
        // Every fused-op chain must end by storing the result.
        fo.push(ProtoFusedSpec::Store);
    }
    // Pair each packed kernel with its fused-op chain, preserving shape.
    let mut iter = kernels.iter().cloned().zip(fused_ops.iter().cloned());
    let micro_ops = ArrayD::from_shape_fn(shape, |_| iter.next().unwrap());
    let wire = model.wire_node(
        format!("{}.matmatmul", name),
        LirMatMulUnary {
            c_fact: c_datum_type.fact(mmm_output_shape.clone()),
            micro_ops,
            c_m_axis,
            c_n_axis,
            c_final_shape: mmm_output_shape,
            reshape_post: vec![],
            geometry,
            mmm,
        },
        &[wire],
    )?[0];
    Ok(wire)
}
/// Build a `DepthWise` op equivalent to this convolution for a concrete
/// input fact. A missing bias is normalized to a zero vector with one value
/// per input channel.
///
/// # Errors
/// Fails if the input shape is not fully concrete (previously a panic via
/// `unwrap`), or if geometry/kernel conversion fails.
pub fn to_depth_wise<T>(&self, input: &TypedFact) -> TractResult<Box<dyn TypedOp>>
where
    T: Datum + Clone + ::ndarray::LinalgScalar + PartialEq + Sum,
{
    // Depthwise evaluation needs every input dim to be known.
    let input_shape = input
        .shape
        .as_concrete()
        .context("to_depth_wise requires a concrete input shape")?;
    let ConcretePoolGeometry { input_shape, patch, output_shape } =
        self.pool_spec.compute_geo(&input.shape)?.to_concrete(input_shape)?.into_owned();
    let bias = if let Some(b) = &self.bias {
        b.clone()
    } else {
        // No bias: use an all-zero vector so DepthWise can add unconditionally.
        Tensor::zero::<T>(&[*input_shape.c()])?.into_arc_tensor()
    };
    let op = DepthWise::new(
        patch,
        input_shape,
        output_shape,
        self.kernel_as_group_o_ihw().context("in kernel_as_group_o_ihw")?,
        bias,
    );
    Ok(Box::new(op))
}
/// Declutter rule: when a spatial axis has stride > 1 but the kernel is
/// trivial along it (size 1, or valid padding with dilation divisible by the
/// stride), extract the stride into an explicit Downsample node ahead of a
/// stride-1 copy of this conv.
fn declutter_stride_slice_to_downsample(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let input_fact = model.outlet_fact(node.inputs[0])?;
    let spatial_rank = self.kernel.rank() - 2;
    // Find one eligible axis; one rewrite per declutter pass is enough.
    if let Some(axis) = (0..spatial_rank).find(|&ax| {
        self.pool_spec.stride(ax) > 1
            && (self.pool_spec.kernel_shape[ax] == 1
                || (self.pool_spec.padding.valid_dim(ax, self.pool_spec.stride(ax) == 1)
                    && self.pool_spec.dilation(ax) % self.pool_spec.stride(ax) == 0))
    }) {
        let downsample_factor = self.pool_spec.stride(axis);
        // New conv with the stride (and dilation share) moved out.
        let mut new_op = self.clone();
        if new_op.pool_spec.dilation(axis) > 1 {
            new_op.pool_spec.dilations.as_mut().unwrap()[axis] /= downsample_factor;
        }
        new_op.pool_spec.strides.as_mut().unwrap()[axis] /= downsample_factor;
        let mut patch = TypedModelPatch::default();
        let tap = patch.tap_model(model, node.inputs[0])?;
        let shape = self
            .pool_spec
            .data_format
            .shape(input_fact.shape.iter().collect::<TVec<TDim>>())?;
        // Downsample acts on the full-tensor axis: spatial index + h offset.
        let down = patch.wire_node(
            format!("{}.downsample.{}", node.name, axis),
            crate::ops::Downsample::new(axis + shape.h_axis(), downsample_factor as isize, 0),
            &[tap],
        )?;
        let id = patch.wire_node(&*node.name, new_op, &down)?[0];
        patch.shunt_outside(model, OutletId::new(node.id, 0), id)?;
        return Ok(Some(patch));
    }
    Ok(None)
}
/// Declutter rule: a 1D, ungrouped, stride-1 convolution whose kernel holds
/// exactly input_channels * output_channels values is a matrix
/// multiplication — rewrite it as (Q)MatMul, folding the bias in (inside the
/// QMatMul for the quantized case, as a separate add otherwise).
fn declutter_as_matmul(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    use crate::ops::matmul::*;
    let input_fact = model.outlet_fact(node.inputs[0])?;
    let full_input_shape = input_fact.shape.to_tvec();
    let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
    if input_shape.hw_rank() == 1
        && self.group == 1
        && self.pool_spec.stride(0) == 1
        && self.kernel.len() == self.input_channels() * self.output_channels()
    {
        let ci = self.input_channels();
        let co = self.output_channels();
        let ker = self.kernel.clone().into_tensor();
        // HWIO stores (ci, co) and needs a transposed a; OIHW stores (co, ci).
        let (a_shape, a_trans) = if self.kernel_fmt == KernelFormat::HWIO {
            ([ci, co], true)
        } else {
            ([co, ci], false)
        };
        let a = ker
            .into_shape(&a_shape)?
            .broadcast_into_rank(full_input_shape.len())?
            .into_arc_tensor();
        // Channel-last formats transpose the data side of the matmul.
        let trans_data = self.pool_spec.data_format == DataFormat::HWC
            || self.pool_spec.data_format == DataFormat::NHWC;
        let mut patch = TypedModelPatch::new("declutter_as_matmul");
        let a = patch.add_const(format!("{}.filters", &node.name), a)?;
        let mut inputs = node
            .inputs
            .iter()
            .map(|i| patch.tap_model(model, *i))
            .collect::<TractResult<TVec<_>>>()?;
        inputs.insert(0, a);
        let axes = MatMulAxes::default_for_rank(full_input_shape.len())
            .transposing(a_trans, trans_data, trans_data);
        // in Q case, the bias has to be injected inside the QMatMul (as it
        // must be added before requantization)
        let wire = if let Some(q_params) = &self.q_params {
            let mut params = q_params.1.clone();
            params.insert_input(0); // kernel as input
            params.insert_input(2); // bias as input
            let bias = self.bias.clone().unwrap_or_else(|| rctensor0(0i32));
            anyhow::ensure!(bias.rank() == 0 || bias.rank() == 1);
            let bias = patch.add_const(format!("{}.bias", &node.name), bias)?;
            inputs.insert(2, bias);
            let op = QMatMul { axes, output_type: q_params.0, params: q_params.1.clone() };
            patch.wire_node(&*node.name, op, &inputs)?[0]
        } else {
            let op = MatMul { axes };
            let mut wire = patch.wire_node(format!("{}.matmul", node.name), op, &inputs)?[0];
            // Non-quantized: bias becomes a broadcast add after the matmul.
            // (the filter is redundant here: q_params is None in this branch)
            if let Some(b) = self.bias.as_ref().filter(|_| self.q_params.is_none()) {
                anyhow::ensure!(b.rank() == 0 || b.rank() == 1);
                let mut bias_shape = tvec!(1; input_shape.rank());
                bias_shape[input_shape.c_axis()] = co;
                let b = b.clone().into_tensor().into_shape(&bias_shape)?;
                let b =
                    patch.add_const(format!("{}.bias.cst", node.name), b.into_arc_tensor())?;
                wire = patch.wire_node(
                    format!("{}.bias", node.name),
                    crate::ops::math::add(),
                    &[wire, b],
                )?[0];
            }
            wire
        };
        patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
        return Ok(Some(patch));
    }
    Ok(None)
}
/// Declutter rule: absorb a preceding zero-valued constant Pad node into
/// this convolution's explicit padding spec, provided the pad touches only
/// spatial axes (not batch or channel). Applies only when the current
/// padding is Valid or already Explicit.
fn declutter_precursor_padding(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    // Only Valid or Explicit padding can absorb extra explicit padding.
    if self.pool_spec.padding != PaddingSpec::Valid
        && !matches!(self.pool_spec.padding, PaddingSpec::Explicit(_, _, _))
    {
        return Ok(None);
    }
    let prec = model.node(node.inputs[0].node);
    let pad = if let Some(pad) = prec.op_as::<Pad>() { pad } else { return Ok(None) };
    let value = if let PadMode::Constant(c) = &pad.mode {
        c
    } else {
        return Ok(None);
    };
    let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
    // Bail out unless the pad is zero-valued and spatial-only.
    if value.cast_to_scalar::<i64>()? != 0
        || (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
        || pad.pads[shape.c_axis()] != (0, 0)
    {
        return Ok(None);
    }
    let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
    let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
    // Accumulate on top of any explicit padding the conv already has.
    if let PaddingSpec::Explicit(bef, aft, false) = &self.pool_spec.padding {
        izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
        izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
    }
    let padding = PaddingSpec::Explicit(before, after, false);
    let mut new = self.clone();
    new.pool_spec.padding = padding;
    // Rewire the new conv directly to the Pad node's own input.
    let mut patch = TypedModelPatch::default();
    let wire = patch.tap_model(model, prec.inputs[0])?;
    let wire = patch.wire_node(&node.name, new, &[wire])?;
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}
}
impl Op for ConvUnary {
    fn name(&self) -> Cow<str> {
        "ConvUnary".into()
    }

    /// Human-readable description: the pool spec's own lines, then kernel
    /// format/grouping, then the bias when present.
    fn info(&self) -> TractResult<Vec<String>> {
        let mut lines = self.pool_spec.info();
        lines.push(format!(
            "Kernel {:?} (groups:{}), {:?}",
            self.kernel_fmt, self.group, self.kernel
        ));
        if let Some(bias) = &self.bias {
            lines.push(format!("Bias: {:?}", bias));
        }
        Ok(lines)
    }

    /// Results may differ by rounding from a reference implementation.
    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}
impl EvalOp for ConvUnary {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Eager evaluation: build a throwaway TypedModel (one source per input,
    /// then the im2col pipeline — quantized or not), make it runnable, and
    /// run it on the actual inputs.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();
        let mut wires: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| {
                model.add_source(format!("source.{}", ix), v.datum_type().fact(v.shape()))
            })
            .collect::<TractResult<_>>()?;
        // u8 kernels are rewritten to i8 first; this may rewire `wires`.
        let new_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
        let wire = unsafe {
            if self.q_params.is_some() {
                // Use the i8-offset op when the rewrite produced one.
                let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
                op_ref.wire_as_quant_im2col(
                    &mut model,
                    "im2col-adhoc",
                    inputs[0].datum_type(),
                    &wires,
                )?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
            }
        };
        model.set_output_outlets(&[wire])?;
        model.into_runnable()?.run(inputs)
    }
}
impl TypedOp for ConvUnary {
/// Validate input arity and consistency (channel counts, bias shape, datum
/// types), then derive the output fact from the pool spec — overriding its
/// datum type with the quantized output type when q_params is set.
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
    // Quantized convs take extra inputs for runtime quantization params.
    let q_inputs = self.q_params.as_ref().map(|(_, qp)| qp.input_count()).unwrap_or(0);
    if inputs.len() != 1 + q_inputs {
        bail!("Wrong number of inputs: expected {} got {}", 1 + q_inputs, inputs.len());
    }
    // Input channel count must match what the kernel expects.
    if self.pool_spec.data_format.shape(&*inputs[0].shape)?.c()
        != &self.input_channels().to_dim()
    {
        bail!(
            "Inconsistent convolution: input is {:?}, kernel expects {} input channels, {:?}",
            inputs[0],
            self.input_channels(),
            self
        );
    }
    // The pool spec must agree with the kernel on output channels.
    if self.pool_spec.output_channel_override != Some(self.output_channels()) {
        bail!(
            "Inconsistent convolution: output channels from pool spec is {:?}, kernel expects {} output channels, {:?}",
            self.pool_spec.output_channel_override,
            self.output_channels(),
            self
        );
    }
    if let Some(bias) = &self.bias {
        ensure!(
            bias.rank() == 0 || (bias.rank() == 1 && bias.len() == self.output_channels()),
            "Bias should be scalar or a vector with one value per output channel, got:{:?}",
            bias
        );
    }
    let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
    if let Some((dt, _qp)) = self.q_params.as_ref() {
        // Quantized: the output type comes from q_params, not the input.
        fact.datum_type = *dt;
    } else {
        ensure!(
            inputs[0].datum_type == self.kernel.datum_type(),
            "Convolution input and weights must have the same type. (resp {:?} and {:?})",
            inputs[0].datum_type,
            self.kernel.datum_type(),
        )
    }
    Ok(tvec!(fact))
}
/// Report axes the convolution is trivially invariant along: the batch (N)
/// axis when present, plus any spatial axis where the kernel has size 1 and
/// the stride is 1. Extra (quantization) inputs carry no matching axis.
fn invariants(
    &self,
    inputs: &[&TypedFact],
    _outputs: &[&TypedFact],
) -> TractResult<Invariants> {
    let input_fact = &inputs[0];
    let shape =
        self.pool_spec.data_format.shape(input_fact.shape.iter().collect::<Vec<TDim>>())?;
    let extra_inputs = inputs.len() - 1;
    // Build a disposable AxisInfo with no mapping for the extra inputs.
    let make_axis = |axis: usize| {
        let mut info = AxisInfo::simple(axis).disposable(true);
        info.inputs.extend(std::iter::repeat(None).take(extra_inputs));
        info
    };
    let mut axes = vec![];
    if let Some(n_axis) = shape.n_axis() {
        axes.push(make_axis(n_axis));
    }
    let kernel_spatial_shape =
        &self.kernel.shape()[self.kernel_fmt.h_axis()..][..shape.hw_rank()];
    let h_axis = shape.h_axis();
    for (ix, &dim) in kernel_spatial_shape.iter().enumerate() {
        if dim == 1 && self.pool_spec.stride(ix) == 1 {
            axes.push(make_axis(ix + h_axis));
        }
    }
    Ok(axes.into_iter().collect())
}
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some((_, qp)) = self.q_params.as_ref() {
if let Some((inputs, qp)) = qp.inline_static(model, node)? {
let mut op = self.clone();
op.q_params.as_mut().unwrap().1 = qp;
let patch = TypedModelPatch::replace_single_op(model, node, &inputs, op)?
.with_context("inlining quantized conv params");
return Ok(Some(patch));
}
}
for d in &[Self::declutter_stride_slice_to_downsample, Self::declutter_as_matmul] {
if let Some(p) = d(self, model, node)? {
return Ok(Some(p));
}
}
if let Some(p) = self.declutter_precursor_padding(model, node)? {
return Ok(Some(p));
}
Ok(None)
}
/// Cost model: parameter count (kernel + bias, unquantized type) plus the
/// FMA count N * C_in * C_out * output_points * kernel_surface / group.
fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
    let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
    let kernel_spatial_shape =
        &self.kernel.shape()[self.kernel_fmt.h_axis()..][..shape.hw_rank()];
    // Output spatial extent after padding/dilation/stride.
    let output_dims = self.pool_spec.padding.compute(
        shape.hw_dims(),
        kernel_spatial_shape,
        &self
            .pool_spec
            .dilations
            .clone()
            .unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
        &self.pool_spec.strides.clone().unwrap_or_else(|| tvec!(1; kernel_spatial_shape.len())),
    );
    let n_output_points: TDim =
        output_dims.iter().map(|d| d.convoluted.clone()).product::<TDim>();
    let n_output_channels = self.output_channels().to_dim();
    let kernel_surface = kernel_spatial_shape.iter().product::<usize>().to_dim();
    let one = 1.to_dim();
    Ok(tvec!(
        (
            Cost::Params(inputs[0].datum_type.unquantized()),
            (self.kernel.len() + self.bias.as_ref().map(|b| b.len()).unwrap_or(0)).to_dim()
        ),
        (
            Cost::FMA(inputs[0].datum_type),
            // Division by group: each output channel only sees its group's inputs.
            shape.n().cloned().unwrap_or(one)
                * shape.c()
                * n_output_channels
                * n_output_points
                * kernel_surface
                / self.group
        )
    ))
}
/// Respond to an axis change proposal on the conv's input/output: accept
/// batch-axis removal, NCHW<->NHWC (and CHW<->HWC) format swaps, and
/// spatial-axis Rm/Add/Move — rebuilding the op (pool spec and kernel) to
/// match. Returns None for changes the conv cannot absorb.
fn change_axes(
    &self,
    model: &TypedModel,
    node: &TypedNode,
    _io: InOut,
    change: &AxisOp,
) -> TractResult<Option<AxisChangeConsequence>> {
    let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
    let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
    // remove n
    if let Some(n) = shape.n_axis() {
        assert_eq!(n, 0);
        if change == &AxisOp::Rm(n) {
            // Dropping the batch axis: switch to the batchless data format.
            let op = ConvUnary { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
            return Ok(Some(AxisChangeConsequence::new(
                model,
                node,
                Some(Box::new(op)),
                change,
            )));
        }
        // Any other change that keeps an axis in front of n is not absorbable.
        if change.transform_axis(n).map(|axis| axis > 0).unwrap_or(true) {
            return Ok(None);
        }
    }
    // format swap: chw <-> hwc
    let (new_format, axis_move) = match self.pool_spec.data_format {
        DataFormat::NCHW => {
            (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
        }
        DataFormat::CHW => {
            (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
        }
        DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
        DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
    };
    if *change == axis_move {
        // The move is exactly the format flip: just relabel the data format.
        let mut new_op = self.clone();
        new_op.pool_spec.data_format = new_format;
        return Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(new_op)),
            wire_changes: tvec!(
                (InOut::In(0), change.clone()),
                (InOut::Out(0), change.clone())
            ),
        }));
    }
    // geo axis manips
    use AxisOp::*;
    let h_axis = shape.h_axis();
    let hw_axes = shape.hw_axes();
    // Kernel spatial dims start at index 2 for OIHW, 0 for HWIO.
    let kh_axis = if self.kernel_fmt == KernelFormat::OIHW { 2 } else { 0 };
    // Map the data-space change to the matching (geo, kernel) changes.
    let (geo_adjusted, kernel_adjusted) = match change {
        // A spatial axis can only be removed when the conv is trivial on it.
        Rm(a)
            if hw_axes.contains(a)
                && self.pool_spec.dilation(a - h_axis) == 1
                && self.pool_spec.stride(a - h_axis) == 1
                && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
        {
            (Rm(a - h_axis), Rm(a - h_axis + kh_axis))
        }
        Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
        Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
            (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
        }
        _ => return Ok(None),
    };
    // Apply the same transform to the kernel tensor and every spatial spec.
    let mut kernel = self.kernel.clone().into_tensor();
    kernel_adjusted.change_tensor(&mut kernel, false)?;
    let mut dilations = self.pool_spec.dilations().into_owned().into();
    geo_adjusted.change_shape_array(&mut dilations, false)?;
    let mut kernel_shape = self.pool_spec.kernel_shape.clone();
    geo_adjusted.change_shape_array(&mut kernel_shape, false)?;
    let mut strides = self.pool_spec.strides().into_owned().into();
    geo_adjusted.change_shape_array(&mut strides, false)?;
    let new_op = ConvUnary {
        pool_spec: PoolSpec {
            data_format: self.pool_spec.data_format,
            padding: self.pool_spec.padding.clone(), // fixme (explicit padding)
            dilations: Some(dilations),
            kernel_shape,
            strides: Some(strides),
            output_channel_override: self.pool_spec.output_channel_override,
        },
        kernel_fmt: self.kernel_fmt,
        kernel: kernel.into_arc_tensor(),
        group: self.group,
        bias: self.bias.clone(),
        q_params: self.q_params.clone(),
    };
    Ok(Some(AxisChangeConsequence {
        substitute_op: Some(Box::new(new_op)),
        wire_changes: tvec!((InOut::In(0), change.clone()), (InOut::Out(0), change.clone())),
    }))
}
fn codegen(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let DatumType::U8 = self.kernel.datum_type().unquantized() {
let mut patch = TypedModelPatch::default();
let mut inputs = node
.inputs
.iter()
.map(|w| patch.tap_model(model, *w))
.collect::<TractResult<TVec<_>>>()?;
let new_op = self.kernel_offset_u8_as_i8(&mut inputs, &mut patch)?.unwrap();
let wire = patch.wire_node(&node.name, new_op, &inputs)?;
patch.shunt_outside(model, node.id.into(), wire[0])?;
patch.obliterate(node.id)?;
return Ok(Some(patch.with_context("kernel-u8-to-i8")));
}
let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
let input_fact = model.outlet_fact(node.inputs[0])?;
let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
let spatial_rank = input_shape.hw_rank();
let kernel_spatial_shape = &self.kernel.shape()[self.kernel_fmt.h_axis()..][..spatial_rank];
unsafe {
let dt = input_fact.datum_type;
if self.q_params.is_some() {
let mut patch = TypedModelPatch::default();
let inputs = node
.inputs
.iter()
.map(|w| patch.tap_model(model, *w))
.collect::<TractResult<TVec<_>>>()?;
let wire = self.wire_as_quant_im2col(
&mut patch,
&node.name,
model.node_input_facts(node.id)?[0].datum_type,
&inputs,
)?;
patch.shunt_outside(model, node.id.into(), wire)?;
patch.obliterate(node.id)?;
Ok(Some(patch.with_context("quantized-codegen")))
} else if kernel_spatial_shape.iter().product::<usize>() == 1
&& (0..spatial_rank)
.all(|i| self.pool_spec.stride(i) == 1 && self.pool_spec.dilation(i) == 1)
&& self.group == 1
{
use crate::ops::matmul::MatMulUnary;
let mut patch = TypedModelPatch::default();
let mut wire = patch.tap_model(model, node.inputs[0])?;
let input_c_is_last = input_shape.c_axis() == input_shape.rank() - 1;
let geo_dim: TDim = input_shape.hw_dims().iter().product();
wire = patch.wire_node(
format!("{}.reshape_input", &*node.name),
AxisOp::Reshape(
input_shape.h_axis(),
input_shape.hw_dims().into(),
tvec!(geo_dim.clone()),
),
&[wire],
)?[0];
let kernel_shape = match self.kernel_fmt {
KernelFormat::HWIO => &self.kernel.shape()[spatial_rank..],
KernelFormat::OIHW => &self.kernel.shape()[..2],
};
let operating_rank = input_fact.rank() + 1 - kernel_spatial_shape.len();
let kernel = self
.kernel
.as_ref()
.clone()
.into_shape(kernel_shape)?
.broadcast_into_rank(operating_rank)?;
wire = patch.wire_node(
&format!("{}.matmul", &node.name),
MatMulUnary::new(
kernel.into_arc_tensor(),
MatMulAxes::default_for_rank(operating_rank).transposing(
self.kernel_fmt == KernelFormat::HWIO,
input_c_is_last,
input_c_is_last,
),
),
&[wire],
)?[0];
if let Some(ref bias) = self.bias {
let bias_shape =
if input_c_is_last { [1, bias.len()] } else { [bias.len(), 1] };
let bias = bias
.clone()
.into_tensor()
.into_shape(&bias_shape)?
.broadcast_into_rank(operating_rank)?
.into_arc_tensor();
let bias = patch.add_const(format!("{}.bias.cst", node.name), bias)?;
wire = patch.wire_node(
format!("{}.bias", node.name),
crate::ops::math::add(),
&[wire, bias],
)?[0];
}
wire = patch.wire_node(
&*node.name,
AxisOp::Reshape(
input_shape.h_axis(),
tvec!(geo_dim),
input_shape.hw_dims().into(),
),
&[wire],
)?[0];
patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
patch.obliterate(node.id)?;
Ok(Some(patch))
} else if input_fact
.shape
.as_concrete()
.map(|s| {
should_use_lazy(
&self.pool_spec.data_format.shape(s.into()).unwrap(),
&self.pool_spec,
self.group,
)
})
.unwrap_or(false)
{
let mut patch = TypedModelPatch::new("wire_as_lazy_im2col");
let mut wire = patch.tap_model(model, node.inputs[0])?;
wire = self.wire_as_lazy_im2col(&mut patch, &node.name, wire)?;
patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
patch.obliterate(node.id)?;
Ok(Some(patch))
} else if self.group != 1
&& self.group == self.output_channels()
&& self.group == self.input_channels()
&& input_fact.shape.as_concrete().is_some()
{
let op = dispatch_floatlike!(Self::to_depth_wise(dt)(self, input_fact))
.context("in to_depth_wise")?;
Ok(Some(TypedModelPatch::single_unary_op(model, node, op)?))
} else {
let mut patch = TypedModelPatch::default();
let wire = patch.tap_model(model, node.inputs[0])?;
let wire = self
.wire_as_im2col_pair(&mut patch, &node.name, wire)
.context("in wire_as_im2col_pair")?;
patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
patch.obliterate(node.id)?;
Ok(Some(patch))
}
}
}source§impl<F, O> Graph<F, O>where
F: Fact + Clone + 'static + Hash + for<'a> From<&'a F>,
O: Display + Debug + Clone + AsRef<dyn Op> + AsMut<dyn Op> + 'static + Hash + for<'a> From<&'a O>,
Graph<F, O>: SpecialOps<F, O>,
impl<F, O> Graph<F, O>where
F: Fact + Clone + 'static + Hash + for<'a> From<&'a F>,
O: Display + Debug + Clone + AsRef<dyn Op> + AsMut<dyn Op> + 'static + Hash + for<'a> From<&'a O>,
Graph<F, O>: SpecialOps<F, O>,
sourcepub fn check_compact(&self) -> TractResult<()>
pub fn check_compact(&self) -> TractResult<()>
sourcepub fn compact(&mut self) -> TractResult<()>
pub fn compact(&mut self) -> TractResult<()>
Examples found in repository (tract optimizer source, lines 89–130):
/// Runs all optimizer passes over `model` repeatedly until a fixed point
/// is reached (a full round applies no patch).
///
/// Preflight: the model must pass a consistency check and is compacted
/// before the first round, and again after every productive round.
pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
    model.check_consistency().context("during optimizer preflight check")?;
    model.compact().context("during optimizer preflight compaction")?;
    for i in 0.. {
        let old = self.counter;
        self.run_all_passes(i, model)?;
        if old == self.counter {
            // counter unchanged: no pass applied a patch this round, converged
            return Ok(());
        }
        model.compact()?;
    }
    // `0..` is endless; the loop can only exit via the `return` above
    unreachable!()
}
/// Runs every configured pass once for round `i`, compacting and
/// consistency-checking the model after each pass so later passes see a
/// clean graph.
pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
    // clone the pass list so passes can be borrowed mutably while `self`
    // is also passed down to run_one_pass_outer
    let mut passes = self.optimizer.passes.clone();
    for p in passes.iter_mut() {
        self.run_one_pass_outer(i, p.as_mut(), model)
            .with_context(|| format!("running pass {:?}", p))?;
        model.compact()?;
        model
            .check_consistency()
            .with_context(|| format!("consistency check after pass {:?}", p))?;
    }
    Ok(())
}
pub fn run_one_pass_outer(
&mut self,
i: usize,
p: &mut dyn TypedPass,
model: &mut TypedModel,
) -> TractResult<()> {
loop {
let old_counter = self.counter;
self.run_one_pass_inner(i, p, model)?;
if self.counter == old_counter {
return Ok(());
}
model.compact().with_context(|| format!("after pass {:?}", p))?;
}
}530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613
fn try_body_axes_change(
&self,
change: AxisChange,
locked_interface: bool,
) -> TractResult<Option<AxisChangeConsequence>> {
self.body.check_consistency()?;
let interface = self.body_exposed_outlets()?;
let (patch, body_changed_wires) = if let Some(changes) =
crate::ops::change_axes::change_axes(
&self.body,
&change,
if locked_interface { &interface } else { &[] },
&self.body_bounds()?,
)? {
changes
} else {
return Ok(None);
};
let mut body = self.body.clone();
patch.apply(&mut body)?;
body.compact()?;
let mut wire_changes = tvec!();
let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
for (ix, m) in input_mapping.iter_mut().enumerate() {
if let Some(change) = body_changed_wires
.iter()
.find(|(iface, _change)| iface == &InOut::In(ix))
.map(|pair| pair.1.clone())
{
if let Some(slot) = m.slot() {
wire_changes.push((InOut::In(slot), change.clone()));
}
match &*m {
InputMapping::Full { .. } => (),
&InputMapping::Scan(info) => {
if let Some(axis) = change.transform_axis(info.axis) {
*m = InputMapping::Scan(ScanInfo { axis, ..info });
} else {
return Ok(None);
};
}
InputMapping::State { initializer } => match initializer {
StateInitializer::FromInput(_) => (),
StateInitializer::Value(ref v) => {
let mut v = v.clone().into_tensor();
change.change_tensor(&mut v, false)?;
*m = InputMapping::State {
initializer: StateInitializer::Value(v.into_arc_tensor()),
};
}
},
};
}
}
let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
for (ix, m) in output_mapping.iter_mut().enumerate() {
if let Some(change) = body_changed_wires
.iter()
.find(|(iface, _change)| iface == &InOut::Out(ix))
.map(|pair| pair.1.clone())
{
if let Some(info) = m.scan.as_mut() {
if let Some(new_axis) = change.transform_axis(info.axis) {
info.axis = new_axis;
} else {
return Ok(None);
}
wire_changes.push((InOut::Out(info.slot), change.clone()));
}
if let Some(slot) = m.last_value_slot {
wire_changes.push((InOut::Out(slot), change.clone()));
}
};
}
body.check_consistency()?;
let op = Some(Box::new(Scan {
body,
input_mapping,
output_mapping,
decluttered: false,
..self.clone()
}) as _);
Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
}pub fn into_compact(self) -> TractResult<Self>
source§impl Graph<TypedFact, Box<dyn TypedOp + 'static, Global>>
impl Graph<TypedFact, Box<dyn TypedOp + 'static, Global>>
pub fn signature(&self) -> u64
sourcepub fn into_optimized(self) -> TractResult<TypedModel>
pub fn into_optimized(self) -> TractResult<TypedModel>
Examples found in repository?
21 22 23 24 25 26 27 28 29 30 31 32 33 34
pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<LirScan> {
let mut model = self.body.clone();
if optimize_inner {
model = model.into_optimized()?;
}
let plan = SimplePlan::new(model)?;
Ok(LirScan::new(Arc::new(LirScanOpParams::new(
self.skip,
Arc::new(plan),
self.input_mapping.clone(),
self.output_mapping.clone(),
))))
}sourcepub fn check_consistency(&self) -> TractResult<()>
pub fn check_consistency(&self) -> TractResult<()>
Examples found in repository?
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
/// Runs all optimizer passes over `model` repeatedly until a fixed point
/// is reached (a full round applies no patch).
///
/// Preflight: the model must pass a consistency check and is compacted
/// before the first round, and again after every productive round.
pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
    model.check_consistency().context("during optimizer preflight check")?;
    model.compact().context("during optimizer preflight compaction")?;
    for i in 0.. {
        let old = self.counter;
        self.run_all_passes(i, model)?;
        if old == self.counter {
            // counter unchanged: no pass applied a patch this round, converged
            return Ok(());
        }
        model.compact()?;
    }
    // `0..` is endless; the loop can only exit via the `return` above
    unreachable!()
}
/// Runs every configured pass once for round `i`, compacting and
/// consistency-checking the model after each pass so later passes see a
/// clean graph.
pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
    // clone the pass list so passes can be borrowed mutably while `self`
    // is also passed down to run_one_pass_outer
    let mut passes = self.optimizer.passes.clone();
    for p in passes.iter_mut() {
        self.run_one_pass_outer(i, p.as_mut(), model)
            .with_context(|| format!("running pass {:?}", p))?;
        model.compact()?;
        model
            .check_consistency()
            .with_context(|| format!("consistency check after pass {:?}", p))?;
    }
    Ok(())
}
/// Repeats a single pass until it stops producing patches (the patch
/// counter stalls), compacting the model between iterations.
pub fn run_one_pass_outer(
    &mut self,
    i: usize,
    p: &mut dyn TypedPass,
    model: &mut TypedModel,
) -> TractResult<()> {
    loop {
        let old_counter = self.counter;
        self.run_one_pass_inner(i, p, model)?;
        if self.counter == old_counter {
            // this pass applied no patch on the last run: done with it
            return Ok(());
        }
        model.compact().with_context(|| format!("after pass {:?}", p))?;
    }
}
pub fn run_one_pass_inner(
&mut self,
i: usize,
p: &mut dyn TypedPass,
model: &mut TypedModel,
) -> TractResult<()> {
p.reset()?;
if let Some(steps) = self.optimizer.steps {
if self.counter >= steps {
return Ok(());
}
}
while let Some(mut patch) = p.next(self, model)? {
patch.push_context(format!("{:?}/{}", p, i));
patch.model.check_consistency().context("checking patch internal consistency")?;
model
.check_consistency()
.context("Checking target model consistency before patching")?;
if let Some(watchdog) = patch.dont_apply_twice.take() {
if self.seen.contains(&watchdog) {
debug!("Loop detected: {} seen before", watchdog);
continue;
} else {
self.seen.insert(watchdog);
}
}
debug!("applying patch #{}: {}", self.counter, patch.context.iter().rev().join(" >> "),);
patch.apply(model)?;
model
.check_consistency()
.context("Checking target model consistency after patchign")?;
self.counter += 1;
if let Some(steps) = self.optimizer.steps {
if self.counter >= steps {
return Ok(());
}
}
}
model.check_consistency().with_context(|| format!("after pass {:?}", p))?;
Ok(())
}More examples
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 
552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613
/// Builds a `Scan` op from an inner body model and its input/output
/// mappings.
///
/// # Errors
/// Fails if the body is internally inconsistent, or if the number of
/// mappings does not match the body's input/output outlet counts.
pub fn new(
    body: TypedModel,
    input_mapping: Vec<InputMapping>,
    output_mapping: Vec<OutputMapping<TDim>>,
    seq_length_input_slot: Option<usize>,
    skip: usize,
) -> TractResult<Scan> {
    body.check_consistency()?;
    ensure!(input_mapping.len() == body.input_outlets()?.len());
    ensure!(output_mapping.len() == body.output_outlets()?.len());
    Ok(Scan {
        skip,
        body,
        // body is not yet optimized; declutter_body will do it once
        decluttered: false,
        input_mapping,
        output_mapping,
        seq_length_input_slot,
    })
}
/// Number of scan iterations implied by the given input facts, when it
/// can be derived.
pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
    // to_codegen_op(false) skips inner-body optimization; the unwrap
    // assumes codegen conversion cannot fail here — TODO confirm invariant
    self.to_codegen_op(false).unwrap().iteration_count(inputs)
}
/// Decluttering step: optimizes the scan body once through the session,
/// then sets `decluttered` so the (expensive) inner optimization is not
/// repeated on later rounds.
fn declutter_body(
    &self,
    session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if !self.decluttered {
        let mut new = self.clone();
        let mut body = self.body.clone();
        session.optimize(&mut body)?;
        new.body = body;
        new.decluttered = true;
        Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
    } else {
        Ok(None)
    }
}
/// Decluttering step: gathers axis-change suggestions from every node of
/// the body (in evaluation order) and applies the first suggestion that
/// yields a viable substitute scan op.
fn declutter_body_axes(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let mut suggestions = vec![];
    for n in self.body.eval_order()? {
        let node = self.body.node(n);
        for suggestion in node.op.suggested_axis_changes()? {
            let outlet = suggestion.0.as_outlet(node);
            suggestions.push(AxisChange { outlet, op: suggestion.1 })
        }
    }
    for suggestion in suggestions.into_iter() {
        // locked_interface = true: the scan's outer wiring must not change
        if let Some(op) =
            self.try_body_axes_change(suggestion, true)?.and_then(|c| c.substitute_op)
        {
            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &node.inputs,
                op,
            )?));
        }
    }
    Ok(None)
}
/// Renumbers outer-input slot references after outer input `discarded`
/// has been removed: every slot strictly greater than `discarded` shifts
/// down by one, other slots are untouched.
fn remove_outer_input_from_mappings(
    mappings: &[InputMapping],
    discarded: usize,
) -> Vec<InputMapping> {
    mappings
        .iter()
        .map(|m| match m {
            &InputMapping::Full { slot } => {
                // `(slot > discarded) as usize` is 1 exactly when a shift applies
                InputMapping::Full { slot: slot - (slot > discarded) as usize }
            }
            &InputMapping::Scan(info) => InputMapping::Scan(ScanInfo {
                slot: info.slot - (info.slot > discarded) as usize,
                ..info
            }),
            InputMapping::State { initializer } => {
                let initializer = match initializer {
                    StateInitializer::FromInput(n) => {
                        StateInitializer::FromInput(*n - (*n > discarded) as usize)
                    }
                    // Value initializers carry no outer slot: copy as-is
                    StateInitializer::Value(v) => StateInitializer::Value(v.clone()),
                };
                InputMapping::State { initializer }
            }
        })
        .collect()
}
/// Output-side counterpart of `remove_outer_input_from_mappings`:
/// renumbers scan and last-value slots after outer output `discarded`
/// has been removed.
fn remove_outer_output_from_mappings(
    mappings: &[OutputMapping<TDim>],
    discarded: usize,
) -> Vec<OutputMapping<TDim>> {
    mappings
        .iter()
        .map(|m| OutputMapping {
            scan: m.scan.map(|info| ScanInfo {
                slot: info.slot - (info.slot > discarded) as usize,
                ..info
            }),
            last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
            full_dim_hint: m.full_dim_hint.clone(),
            state: m.state,
        })
        .collect()
}
/// Decluttering step: when a state initializer reads a constant outer
/// input, inlines the constant as a `Value` initializer, drops that outer
/// input from the node, and renumbers the remaining input mappings.
fn declutter_const_initializer(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let inputs = model.node_input_facts(node.id)?;
    for (ix, mapping) in self.input_mapping.iter().enumerate() {
        if let InputMapping::State { initializer: StateInitializer::FromInput(n) } = mapping {
            // konst is Some only when the feeding outlet is a known constant
            if let Some(i) = inputs[*n].konst.as_ref() {
                let mut op = self.clone();
                op.input_mapping[ix] =
                    InputMapping::State { initializer: StateInitializer::Value(i.clone()) };
                op.input_mapping =
                    Self::remove_outer_input_from_mappings(&op.input_mapping, *n);
                let mut inputs = node.inputs.clone();
                inputs.remove(*n);
                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
            }
        }
    }
    Ok(None)
}
/// Decluttering step: removes a body input whose source node has no
/// successors and is not itself a body output, together with its mapping
/// and (when it had one) the corresponding outer input slot.
fn declutter_discard_unused_input_mapping(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (inner_input_id, input) in self.body.input_outlets()?.iter().enumerate() {
        let source_node = self.body.node(input.node);
        if source_node.outputs[0].successors.len() == 0
            && !self.body.output_outlets()?.contains(input)
        {
            let mut new_inputs = node.inputs.clone();
            // the outer slot backing this inner input, if any (a Value
            // state initializer has no outer slot)
            let slot = match &self.input_mapping[inner_input_id] {
                InputMapping::Full { slot } => Some(*slot),
                InputMapping::Scan(info) => Some(info.slot),
                InputMapping::State { initializer } => match initializer {
                    StateInitializer::FromInput(n) => Some(*n),
                    _ => None,
                },
            };
            let mut new_mappings: Vec<_> = self.input_mapping.clone();
            new_mappings.remove(inner_input_id);
            if let Some(slot) = slot {
                // renumber remaining mappings past the removed outer slot
                new_mappings = Self::remove_outer_input_from_mappings(&new_mappings, slot);
            }
            let mut model_inputs = self.body.input_outlets()?.to_vec();
            if let Some(slot) = slot {
                new_inputs.remove(slot);
            }
            model_inputs.remove(inner_input_id);
            let mut body = self.body.clone();
            let mut patch = TypedModelPatch::default();
            patch.obliterate(source_node.id)?;
            patch.apply(&mut body)?;
            body.set_input_outlets(&model_inputs)?;
            body.declutter()?;
            let op = Self {
                body,
                skip: self.skip,
                seq_length_input_slot: self.seq_length_input_slot,
                input_mapping: new_mappings,
                decluttered: true,
                output_mapping: self.output_mapping.clone(),
            };
            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
        }
    }
    Ok(None)
}
/// Decluttering step: drops an outer scan output that has no successors
/// and is not a model output, clearing it from the output mappings and
/// rewiring the remaining outputs (with slots shifted down past it).
fn declutter_discard_useless_outer_output(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (ix, o) in node.outputs.iter().enumerate() {
        if o.successors.len() == 0
            && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
        {
            // clear any mapping entry that targeted the dead slot `ix`
            let mappings = self
                .output_mapping
                .iter()
                .map(|m| OutputMapping {
                    scan: m.scan.filter(|info| info.slot != ix),
                    last_value_slot: m.last_value_slot.filter(|s| *s != ix),
                    full_dim_hint: m.full_dim_hint.clone(),
                    state: m.state,
                })
                .collect::<Vec<_>>();
            let mut op = self.clone();
            op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
            let mut patch = TypedModelPatch::default();
            let inputs = node
                .inputs
                .iter()
                .map(|&i| patch.tap_model(model, i))
                .collect::<TractResult<Vec<_>>>()?;
            let wires = patch.wire_node(&*node.name, op, &inputs)?;
            for oix in 0..node.outputs.len() {
                if oix != ix {
                    // outputs after the removed one moved down by one slot
                    patch.shunt_outside(
                        model,
                        OutletId::new(node.id, oix),
                        wires[oix - (oix > ix) as usize],
                    )?;
                }
            }
            return Ok(Some(patch));
        }
    }
    Ok(None)
}
/// Decluttering step: removes a body output whose mapping exports nothing
/// (no scan slot, no last-value slot, not a state), along with its
/// mapping entry.
fn declutter_discard_empty_output_mapping_with_body_output(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (ix, om) in self.output_mapping.iter().enumerate() {
        if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
            let mut new_op = self.clone();
            new_op.output_mapping.remove(ix);
            new_op.body.outputs.remove(ix);
            // body changed: let declutter_body re-optimize it later
            new_op.decluttered = false;
            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &node.inputs,
                new_op,
            )?));
        }
    }
    Ok(None)
}
/// Decluttering step: hoists a single-input/single-output body node that
/// directly consumes a scanned input out of the loop, applying it once on
/// the full outer tensor instead of once per chunk. The hoisted result is
/// fed back in as a new scanned input, provided the op's invariants can
/// track the scan axis through it (`unary_track_axis_down`).
fn declutter_pull_batcheable_input(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (model_input, input) in self.input_mapping.iter().enumerate() {
        if let Some(info) = input.as_scan() {
            let scan_source = self.body.input_outlets()?[model_input];
            let scan_source_node = self.body.node(scan_source.node);
            for successor in &scan_source_node.outputs[0].successors {
                let successor_node = self.body.node(successor.node);
                // only hoist simple unary nodes
                if successor_node.inputs.len() != 1 || successor_node.outputs.len() != 1 {
                    continue;
                }
                let (input_facts, output_facts) = self.body.node_facts(successor_node.id)?;
                let invariants = successor_node.op.invariants(&input_facts, &output_facts)?;
                if let Some(axis_after) = invariants.unary_track_axis_down(info.axis, false) {
                    // outer patch: run the hoisted op on the full input
                    let mut outside_patch = TypedModelPatch::new(format!(
                        "Outer patch for input extraction of {}",
                        successor_node
                    ));
                    let mut patch_inputs = node
                        .inputs
                        .iter()
                        .map(|&i| outside_patch.tap_model(model, i))
                        .collect::<TractResult<TVec<_>>>()?;
                    let input = patch_inputs[info.slot];
                    let new_input_wire = outside_patch.wire_node(
                        format!("{}.extracted.{}", node.name, successor_node.name),
                        successor_node.op.clone(),
                        &[input],
                    )?[0];
                    patch_inputs.push(new_input_wire);
                    // inner fact: same as the outer one, but with the scan
                    // axis reduced to one chunk
                    let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                    let mut new_input_inner_fact = new_input_outer_fact.clone();
                    new_input_inner_fact.shape.set(axis_after, info.chunk.abs().to_dim());
                    let mut new_body = self.body.clone();
                    let new_source_wire = new_body.add_source(
                        format!("{}.extracted.{}", node.name, successor_node.name),
                        new_input_inner_fact,
                    )?;
                    // inner patch: bypass the hoisted node, reading from the
                    // new source instead
                    let mut inner_patch = TypedModelPatch::new(format!(
                        "Inner body patch for extraction of {}",
                        successor_node
                    ));
                    let new_source_wire_in_patch =
                        inner_patch.tap_model(&new_body, new_source_wire)?;
                    inner_patch
                        .shunt_outside(
                            &new_body,
                            OutletId::new(successor.node, 0),
                            new_source_wire_in_patch,
                        )
                        .with_context(|| "patching inner model")?;
                    inner_patch.apply(&mut new_body)?;
                    let mut input_mapping = self.input_mapping.clone();
                    input_mapping.push(InputMapping::Scan(ScanInfo {
                        axis: axis_after,
                        chunk: info.chunk,
                        slot: node.inputs.len(),
                    }));
                    let new_op = Self {
                        input_mapping,
                        output_mapping: self.output_mapping.clone(),
                        decluttered: false,
                        body: new_body,
                        skip: self.skip,
                        seq_length_input_slot: self.seq_length_input_slot,
                    };
                    let output_wires =
                        outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                    for w in output_wires {
                        outside_patch
                            .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                            .with_context(|| "patching outer model")?;
                    }
                    return Ok(Some(outside_patch));
                }
            }
        }
    }
    Ok(None)
}
/// Decluttering step: when a body output exposed as a last-value slot is
/// a known constant, replaces the outer output wire with a plain constant
/// node, so the value no longer needs to come out of the loop.
fn declutter_pull_constant_outputs(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
    for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
        if let Some(slot) = mapping.last_value_slot {
            if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
                let inner_node = self.body.output_outlets()?[model_output_ix].node;
                let inner_node = self.body.node(inner_node);
                let mut patch = TypedModelPatch::new(format!("Extract const node {}", inner_node));
                let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
                patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
                return Ok(Some(patch));
            }
        }
    }
    Ok(None)
}
/// Decluttering step (output-side dual of input pulling): moves the node
/// that produces a scanned output out of the loop, when it has no other
/// consumers inside the body and its invariants can track the scan axis
/// upward. The node's own inputs become new scan (or last-value, for
/// constants) outputs, and the node runs once outside on the accumulated
/// tensors.
fn declutter_pull_batcheable_output(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (model_ix, mapping) in self.output_mapping.iter().enumerate() {
        if let Some(info) = mapping.scan {
            let emitter_outlet = self.body.output_outlets()?[model_ix];
            let emitter_node = self.body.node(emitter_outlet.node);
            if emitter_node.outputs[emitter_outlet.slot].successors.len() > 0
                || mapping.state
                || mapping.scan.map(|i| i.chunk > 1).unwrap_or(true)
            {
                // continue if both last_value and full values are exported
                continue;
            }
            let (input_facts, output_facts) = self.body.node_facts(emitter_node.id)?;
            let invariants = emitter_node.op.invariants(&input_facts, &output_facts)?;
            let Some(axis_before) = invariants.unary_track_axis_up(info.axis, false)
            else {
                continue;
            };
            let mut new_body = self.body.clone();
            let mut new_output_mapping = self.output_mapping.clone();
            let mut new_scan_outputs = node.outputs.len();
            // outer output slot feeding each of the emitter's inputs
            let mut outer_slots = vec![];
            for input in &emitter_node.inputs {
                if new_body.outputs.iter().all(|o| o != input) {
                    new_output_mapping.push(OutputMapping::default());
                    new_body.outputs.push(*input);
                }
                let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                let mut mapping = &mut new_output_mapping[body_output_id];
                let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                    // constants don't change per iteration: last value is enough
                    if mapping.last_value_slot.is_none() {
                        mapping.last_value_slot = Some(new_scan_outputs);
                    }
                    new_scan_outputs += 1;
                    mapping.last_value_slot.unwrap()
                } else {
                    if mapping.scan.is_none() {
                        mapping.scan = Some(ScanInfo {
                            slot: new_scan_outputs,
                            axis: axis_before,
                            chunk: info.chunk,
                        });
                        new_scan_outputs += 1;
                    }
                    mapping.scan.unwrap().slot
                };
                outer_slots.push(outer_slot);
            }
            let mut outside_patch = TypedModelPatch::new(format!(
                "Outside patch for output extraction of {}",
                emitter_node
            ));
            let inputs = node
                .inputs
                .iter()
                .map(|&i| outside_patch.tap_model(model, i))
                .collect::<TractResult<TVec<_>>>()?;
            let new_op = Self {
                input_mapping: self.input_mapping.clone(),
                output_mapping: new_output_mapping,
                decluttered: false,
                body: new_body,
                skip: self.skip,
                seq_length_input_slot: self.seq_length_input_slot,
            };
            let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
            let output = mapping.scan.unwrap();
            // replay the emitter outside the loop on the accumulated outputs
            let inputs =
                outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
            let wire = outside_patch.wire_node(
                &*emitter_node.name,
                emitter_node.op.clone(),
                &inputs,
            )?[0];
            outside_patch.shunt_outside(model, OutletId::new(node.id, output.slot), wire)?;
            for output_slot in 0..node.outputs.len() {
                if output_slot != output.slot {
                    outside_patch.shunt_outside(
                        model,
                        OutletId::new(node.id, output_slot),
                        OutletId::new(scan_outputs[0].node, output_slot),
                    )?;
                }
            }
            return Ok(Some(outside_patch));
        }
    }
    Ok(None)
}
/// Pairs each state input outlet of the body with its corresponding state
/// output outlet (they are matched positionally among state mappings);
/// each pair bounds one recurrent state loop.
fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
    let input_state_outlets = self
        .input_mapping
        .iter()
        .zip(self.body.input_outlets()?.iter())
        .filter(|(m, _)| m.as_state().is_some())
        .map(|(_, o)| o);
    let output_state_outlets = self
        .output_mapping
        .iter()
        .zip(self.body.output_outlets()?.iter())
        .filter(|(m, _)| m.state)
        .map(|(_, o)| o);
    Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
}
/// All body outlets visible from outside the scan: non-invisible input
/// outlets followed by non-invisible output outlets. Used as the locked
/// interface when trying axis changes on the body.
fn body_exposed_outlets(&self) -> TractResult<TVec<OutletId>> {
    let input_outlets = self
        .input_mapping
        .iter()
        .zip(self.body.input_outlets()?.iter())
        .filter(|(m, _)| !m.invisible())
        .map(|(_, o)| o);
    let output_outlets = self
        .output_mapping
        .iter()
        .zip(self.body.output_outlets()?.iter())
        .filter(|(m, _)| !m.invisible())
        .map(|(_, o)| o);
    Ok(input_outlets.chain(output_outlets).cloned().collect())
}
fn try_body_axes_change(
&self,
change: AxisChange,
locked_interface: bool,
) -> TractResult<Option<AxisChangeConsequence>> {
self.body.check_consistency()?;
let interface = self.body_exposed_outlets()?;
let (patch, body_changed_wires) = if let Some(changes) =
crate::ops::change_axes::change_axes(
&self.body,
&change,
if locked_interface { &interface } else { &[] },
&self.body_bounds()?,
)? {
changes
} else {
return Ok(None);
};
let mut body = self.body.clone();
patch.apply(&mut body)?;
body.compact()?;
let mut wire_changes = tvec!();
let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
for (ix, m) in input_mapping.iter_mut().enumerate() {
if let Some(change) = body_changed_wires
.iter()
.find(|(iface, _change)| iface == &InOut::In(ix))
.map(|pair| pair.1.clone())
{
if let Some(slot) = m.slot() {
wire_changes.push((InOut::In(slot), change.clone()));
}
match &*m {
InputMapping::Full { .. } => (),
&InputMapping::Scan(info) => {
if let Some(axis) = change.transform_axis(info.axis) {
*m = InputMapping::Scan(ScanInfo { axis, ..info });
} else {
return Ok(None);
};
}
InputMapping::State { initializer } => match initializer {
StateInitializer::FromInput(_) => (),
StateInitializer::Value(ref v) => {
let mut v = v.clone().into_tensor();
change.change_tensor(&mut v, false)?;
*m = InputMapping::State {
initializer: StateInitializer::Value(v.into_arc_tensor()),
};
}
},
};
}
}
let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
for (ix, m) in output_mapping.iter_mut().enumerate() {
if let Some(change) = body_changed_wires
.iter()
.find(|(iface, _change)| iface == &InOut::Out(ix))
.map(|pair| pair.1.clone())
{
if let Some(info) = m.scan.as_mut() {
if let Some(new_axis) = change.transform_axis(info.axis) {
info.axis = new_axis;
} else {
return Ok(None);
}
wire_changes.push((InOut::Out(info.slot), change.clone()));
}
if let Some(slot) = m.last_value_slot {
wire_changes.push((InOut::Out(slot), change.clone()));
}
};
}
body.check_consistency()?;
let op = Some(Box::new(Scan {
body,
input_mapping,
output_mapping,
decluttered: false,
..self.clone()
}) as _);
Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
}109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
fn pull_downsample_up(
model: &TypedModel,
down_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
model.check_consistency()?;
let down_op = down_node.op_as::<Downsample>().unwrap();
if let Some(prec) = model.single_prec(down_node.id)? {
let (input_facts, output_facts) = model.node_facts(prec.id)?;
let invariants = prec.op.invariants(&input_facts, &output_facts)?;
debug!("Consider pull {:?} over {:?} (invariants: {:?})", down_op, prec, invariants);
if let Some(slice_op) = prec.op_as::<ops::array::Slice>() {
if let Some(p) = array::pull_downsample_over_slice(model, prec, slice_op, down_node, down_op)? {
return Ok(Some(p))
}
} else if let Some(other_op) = prec.op_as::<AxisOp>() {
return array::pull_downsample_over_axis_op(model, prec, other_op, down_node, down_op);
} else if let Some(conv_op) = prec.op_as::<ops::cnn::conv::ConvUnary>() {
return conv::fuse_downsample_into_conv(model, prec, conv_op, down_node, down_op);
} else if let Some(other_op) = prec.op_as::<ops::scan::Scan>() {
return scan::pull_downsample_over_scan(model, prec, other_op, down_node, down_op);
}
if let Some(above_axis) = invariants.unary_track_axis_up(down_op.axis, false) {
let mut patch = TypedModelPatch::default();
let mut inputs = vec![];
for (ix, &oo) in prec.inputs.iter().enumerate() {
let source = patch.tap_model(model, oo)?;
let mut op = down_op.clone();
op.axis = above_axis;
let ds = patch.wire_node(
format!("{}.{}-{}", down_node.name, prec.name, ix),
op,
[source].as_ref(),
)?;
inputs.push(ds[0]);
}
let other = patch.wire_node(&prec.name, prec.op.clone(), &inputs)?;
patch.shunt_outside(model, OutletId::new(down_node.id, 0), other[0])?;
return Ok(Some(patch));
}
}
Ok(None)
}6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
/// Try to pull a `Downsample` node located after a `Scan` node up over the
/// scan, so the scan's inputs are downsampled instead of its outputs.
///
/// Strategy: append a downsample to every output of the scan body, let the
/// declutter pass migrate it toward the body inputs, then rebuild the scan
/// with downsampled sources. Returns `Ok(None)` whenever the transformation
/// does not apply (negative stride, downsamples that fail to migrate all the
/// way to the body inputs, or chunk sizes incompatible with the stride).
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    // Reversed (negative-stride) downsampling is not handled by this pass.
    if down_op.stride < 0 {
        return Ok(None);
    }
    // introduce downsample at end of body
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;
    // check if downsample ops introduced at end have swum up to scan inputs during declutter
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        // Every successor of every body source must now be a downsample
        // identical to `down_op`; otherwise the migration failed — give up.
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }
    // The downsamples reached the sources: absorb them by shrinking the
    // source facts and shunting the now-redundant downsample nodes away.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = &mut downsampled_body.node_mut(input.node);
        let fact = &mut node.outputs[0].fact;
        *fact = down_op.transform_fact(fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>().unwrap().fact = fact.clone();
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }
    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;
    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    // Adjust the outer-facing input mappings to the downsampled geometry.
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                // Constant state initializers must be downsampled too.
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // A positive chunk must be divisible by the stride, or the
                // iteration count would change.
                if info.chunk > 0 && info.chunk as usize % down_op.stride as usize != 0 {
                    return Ok(None);
                }
                info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                    * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            // NOTE(review): unlike the input-mapping case above, a negative
            // chunk is not guarded before the `as usize` cast here — confirm
            // output chunks are always non-negative at this point.
            if info.chunk as usize % down_op.stride as usize != 0 {
                return Ok(None);
            }
            info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                * info.chunk.signum()
        }
    }
    // Build the outer patch: downsample each scan input, wire the new scan,
    // and shunt the old downstream downsample nodes to the new scan outputs.
    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    for ix in 0..scan_node.outputs.len() {
        // FIXME need to check earlier on that all output are followed by a ds
        let succ = scan_node.outputs[ix].successors[0].node;
        patch.shunt_outside(model, OutletId::new(succ, 0), scan[ix])?;
    }
    Ok(Some(patch))
}sourcepub fn into_decluttered(self) -> TractResult<TypedModel>
pub fn into_decluttered(self) -> TractResult<TypedModel>
Examples found in repository?
6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
/// Try to pull a `Downsample` node located after a `Scan` node up over the
/// scan, so the scan's inputs are downsampled instead of its outputs.
///
/// Strategy: append a downsample to every output of the scan body, let the
/// declutter pass migrate it toward the body inputs, then rebuild the scan
/// with downsampled sources. Returns `Ok(None)` whenever the transformation
/// does not apply (negative stride, downsamples that fail to migrate all the
/// way to the body inputs, or chunk sizes incompatible with the stride).
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    // Reversed (negative-stride) downsampling is not handled by this pass.
    if down_op.stride < 0 {
        return Ok(None);
    }
    // introduce downsample at end of body
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;
    // check if downsample ops introduced at end have swum up to scan inputs during declutter
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        // Every successor of every body source must now be a downsample
        // identical to `down_op`; otherwise the migration failed — give up.
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }
    // The downsamples reached the sources: absorb them by shrinking the
    // source facts and shunting the now-redundant downsample nodes away.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = &mut downsampled_body.node_mut(input.node);
        let fact = &mut node.outputs[0].fact;
        *fact = down_op.transform_fact(fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>().unwrap().fact = fact.clone();
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }
    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;
    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    // Adjust the outer-facing input mappings to the downsampled geometry.
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                // Constant state initializers must be downsampled too.
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // A positive chunk must be divisible by the stride, or the
                // iteration count would change.
                if info.chunk > 0 && info.chunk as usize % down_op.stride as usize != 0 {
                    return Ok(None);
                }
                info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                    * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            // NOTE(review): unlike the input-mapping case above, a negative
            // chunk is not guarded before the `as usize` cast here — confirm
            // output chunks are always non-negative at this point.
            if info.chunk as usize % down_op.stride as usize != 0 {
                return Ok(None);
            }
            info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                * info.chunk.signum()
        }
    }
    // Build the outer patch: downsample each scan input, wire the new scan,
    // and shunt the old downstream downsample nodes to the new scan outputs.
    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    for ix in 0..scan_node.outputs.len() {
        // FIXME need to check earlier on that all output are followed by a ds
        let succ = scan_node.outputs[ix].successors[0].node;
        patch.shunt_outside(model, OutletId::new(succ, 0), scan[ix])?;
    }
    Ok(Some(patch))
}sourcepub fn declutter(&mut self) -> TractResult<()>
pub fn declutter(&mut self) -> TractResult<()>
Perform declutter passes on the network.
Examples found in repository?
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
/// Consume the model, declutter it, then run the optimizer passes,
/// returning the fully optimized model.
pub fn into_optimized(self) -> TractResult<TypedModel> {
    let mut model = self;
    model.declutter()?;
    model.optimize()?;
    Ok(model)
}
/// Fast-path stub: consistency checking is compiled out unless both debug
/// assertions and the "paranoid_assertions" feature are enabled.
#[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
#[inline]
pub fn check_consistency(&self) -> TractResult<()> {
    Ok(())
}
/// Paranoid-mode implementation: verifies graph edges, node ids, and that
/// each node's stored output facts agree with what its op recomputes from
/// the current input facts.
#[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
pub fn check_consistency(&self) -> TractResult<()> {
    self.check_edges()?;
    for node_id in &self.eval_order()? {
        let input_facts = self.node_input_facts(*node_id)?;
        let node = &self.nodes[*node_id];
        // A node's id must match its position in the `nodes` vector.
        if node.id != *node_id {
            bail!("Node at position {} has id {}", node_id, node.id);
        }
        // Recompute the output facts from the op and compare with the
        // facts stored on the node.
        let output_facts = node.op.output_facts(&input_facts)?;
        if node.outputs.len() != output_facts.len() {
            bail!(
                "Inconsistent model, node output count mismatch. Op says {}, node says {}. {}",
                output_facts.len(),
                node.outputs.len(),
                node
            );
        }
        // Datum type and shape of every stored fact must match the
        // op-recomputed fact.
        if node
            .outputs
            .iter()
            .map(|o| &o.fact)
            .zip(output_facts.iter())
            .any(|(a, b)| a.datum_type != b.datum_type || a.shape != b.shape)
        {
            bail!(
                "Inconsistent model, output types mismatch. Op says: {:?}, node says: {:?}. {} with inputs {:?}. {}",
                output_facts, node.outputs.iter().map(|o| &o.fact).collect::<Vec<_>>(), node, input_facts, node)
        }
    }
    // Each fact must also be internally consistent on its own.
    for node in &self.nodes {
        for (ix, output) in node.outputs.iter().enumerate() {
            output.fact.consistent().with_context(|| {
                format!("Inconsistent fact {:?}: {:?}", OutletId::new(node.id, ix), output.fact)
            })?
        }
    }
    Ok(())
}
/// Consume the model and return it after a declutter pass.
pub fn into_decluttered(self) -> TractResult<TypedModel> {
    let mut model = self;
    model.declutter()?;
    Ok(model)
}More examples
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
/// Declutter pass: if an inner source of the scan body has no successor
/// (and is not itself a body output), drop that input from the body, from
/// the op's input mapping, and — when it was fed from an outer slot — from
/// the outer node's inputs as well.
fn declutter_discard_unused_input_mapping(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (inner_input_id, input) in self.body.input_outlets()?.iter().enumerate() {
        let source_node = self.body.node(input.node);
        // Unused: no consumer inside the body and not exposed as an output.
        if source_node.outputs[0].successors.len() == 0
            && !self.body.output_outlets()?.contains(input)
        {
            let mut new_inputs = node.inputs.clone();
            // Find which outer input slot (if any) feeds this inner input.
            let slot = match &self.input_mapping[inner_input_id] {
                InputMapping::Full { slot } => Some(*slot),
                InputMapping::Scan(info) => Some(info.slot),
                InputMapping::State { initializer } => match initializer {
                    StateInitializer::FromInput(n) => Some(*n),
                    _ => None,
                },
            };
            let mut new_mappings: Vec<_> = self.input_mapping.clone();
            new_mappings.remove(inner_input_id);
            if let Some(slot) = slot {
                // Renumber remaining mappings past the removed outer slot.
                new_mappings = Self::remove_outer_input_from_mappings(&new_mappings, slot);
            }
            let mut model_inputs = self.body.input_outlets()?.to_vec();
            if let Some(slot) = slot {
                new_inputs.remove(slot);
            }
            model_inputs.remove(inner_input_id);
            // Rebuild the body without the dead source node.
            let mut body = self.body.clone();
            let mut patch = TypedModelPatch::default();
            patch.obliterate(source_node.id)?;
            patch.apply(&mut body)?;
            body.set_input_outlets(&model_inputs)?;
            body.declutter()?;
            let op = Self {
                body,
                skip: self.skip,
                seq_length_input_slot: self.seq_length_input_slot,
                input_mapping: new_mappings,
                decluttered: true,
                output_mapping: self.output_mapping.clone(),
            };
            // Remove one input per pass; the optimizer loop will re-run
            // this declutterer until no unused input remains.
            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
        }
    }
    Ok(None)
}6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
/// Try to pull a `Downsample` node located after a `Scan` node up over the
/// scan, so the scan's inputs are downsampled instead of its outputs.
///
/// Strategy: append a downsample to every output of the scan body, let the
/// declutter pass migrate it toward the body inputs, then rebuild the scan
/// with downsampled sources. Returns `Ok(None)` whenever the transformation
/// does not apply (negative stride, downsamples that fail to migrate all the
/// way to the body inputs, or chunk sizes incompatible with the stride).
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    // Reversed (negative-stride) downsampling is not handled by this pass.
    if down_op.stride < 0 {
        return Ok(None);
    }
    // introduce downsample at end of body
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;
    // check if downsample ops introduced at end have swum up to scan inputs during declutter
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        // Every successor of every body source must now be a downsample
        // identical to `down_op`; otherwise the migration failed — give up.
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }
    // The downsamples reached the sources: absorb them by shrinking the
    // source facts and shunting the now-redundant downsample nodes away.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = &mut downsampled_body.node_mut(input.node);
        let fact = &mut node.outputs[0].fact;
        *fact = down_op.transform_fact(fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>().unwrap().fact = fact.clone();
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }
    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;
    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    // Adjust the outer-facing input mappings to the downsampled geometry.
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                // Constant state initializers must be downsampled too.
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // A positive chunk must be divisible by the stride, or the
                // iteration count would change.
                if info.chunk > 0 && info.chunk as usize % down_op.stride as usize != 0 {
                    return Ok(None);
                }
                info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                    * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            // NOTE(review): unlike the input-mapping case above, a negative
            // chunk is not guarded before the `as usize` cast here — confirm
            // output chunks are always non-negative at this point.
            if info.chunk as usize % down_op.stride as usize != 0 {
                return Ok(None);
            }
            info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                * info.chunk.signum()
        }
    }
    // Build the outer patch: downsample each scan input, wire the new scan,
    // and shunt the old downstream downsample nodes to the new scan outputs.
    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    for ix in 0..scan_node.outputs.len() {
        // FIXME need to check earlier on that all output are followed by a ds
        let succ = scan_node.outputs[ix].successors[0].node;
        patch.shunt_outside(model, OutletId::new(succ, 0), scan[ix])?;
    }
    Ok(Some(patch))
}sourcepub fn optimize_with_session(
&mut self,
session: &mut OptimizerSession<'_>
) -> TractResult<()>
pub fn optimize_with_session(
&mut self,
session: &mut OptimizerSession<'_>
) -> TractResult<()>
Perform optimization passes on the model, using a given optimizer session.
sourcepub fn concretize_dims(&self, values: &SymbolValues) -> TractResult<TypedModel>
pub fn concretize_dims(&self, values: &SymbolValues) -> TractResult<TypedModel>
Examples found in repository?
808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827
/// Rebuild this scan op with every symbolic dimension replaced by the
/// concrete values in `values`, and wire the result into `target`.
fn concretize_dims(
    &self,
    _source: &TypedModel,
    node: &TypedNode,
    target: &mut TypedModel,
    mapping: &HashMap<OutletId, OutletId>,
    values: &SymbolValues,
) -> TractResult<TVec<OutletId>> {
    // Translate this node's input outlets into the target model's outlets.
    let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
    let op = Self {
        // Concretize both the output mappings and the inner body; all
        // other fields are carried over unchanged.
        output_mapping: self
            .output_mapping
            .iter()
            .map(|om| om.concretize_dims(values))
            .collect::<TractResult<Vec<_>>>()?,
        body: self.body.concretize_dims(values)?,
        ..self.clone()
    };
    target.wire_node(&node.name, op, &inputs)
}sourcepub fn optimize(&mut self) -> TractResult<()>
pub fn optimize(&mut self) -> TractResult<()>
Translate the graph to locally optimized operators (LIR or MIR ops).
sourcepub fn invariants(&self) -> TractResult<Invariants>
pub fn invariants(&self) -> TractResult<Invariants>
Examples found in repository?
680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718
/// Derive the axis invariants of the scan node from its body's invariants:
/// each trackable body axis is translated to the outer node's input and
/// output slots through the input/output mappings.
fn invariants(
    &self,
    _inputs: &[&TypedFact],
    _outputs: &[&TypedFact],
) -> TractResult<Invariants> {
    let mut invariants = tvec!();
    let body_invs = self.body.invariants().with_context(|| "Computing body invariants")?;
    for body_axis in body_invs.axes {
        let mut info = AxisInfo::default().with_period(1);
        // Map the body axis onto each outer input slot it corresponds to.
        for (ix, input_mapping) in self.input_mapping.iter().enumerate() {
            if let Some(slot) = input_mapping.slot() {
                // Pad with None so the vector index lines up with the slot.
                while info.inputs.len() <= slot {
                    info.inputs.push(None);
                }
                info.inputs[slot] = body_axis.inputs[ix];
            }
        }
        // Same translation for output slots: an output mapping may expose a
        // scanned output and/or a last-value output.
        for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
            let mut slots = vec![];
            if let Some(scan) = output_mapping.scan {
                slots.push(scan.slot);
            }
            if let Some(slot) = output_mapping.last_value_slot {
                slots.push(slot);
            }
            for slot in slots {
                while info.outputs.len() <= slot {
                    info.outputs.push(None);
                }
                info.outputs[slot] = body_axis.outputs[ix];
            }
        }
        // Only keep axes that actually touch an outer input or output.
        if info.inputs.iter().any(|i| i.is_some()) || info.outputs.iter().any(|i| i.is_some()) {
            info.disposable = body_axis.disposable;
            invariants.push(info);
        }
    }
    Ok(Invariants::from(invariants))
}Trait Implementations§
source§impl<F, O> Clone for Graph<F, O>where
F: Fact + Hash + Clone + 'static + Clone,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash + Clone,
impl<F, O> Clone for Graph<F, O>where
F: Fact + Hash + Clone + 'static + Clone,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash + Clone,
source§impl<F, O> Debug for Graph<F, O>where
F: Fact + Hash + Clone + 'static + Debug,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash + Debug,
impl<F, O> Debug for Graph<F, O>where
F: Fact + Hash + Clone + 'static + Debug,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash + Debug,
source§impl<F, O> Default for Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
impl<F, O> Default for Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
source§impl<F, O> Display for Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
impl<F, O> Display for Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
source§impl<F, O> DynHash for Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
impl<F, O> DynHash for Graph<F, O>where
F: Fact + Hash + Clone + 'static,
O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
Auto Trait Implementations§
impl<F, O> RefUnwindSafe for Graph<F, O>where
F: RefUnwindSafe,
O: RefUnwindSafe,
impl<F, O> Send for Graph<F, O>where
O: Send,
impl<F, O> Sync for Graph<F, O>where
O: Sync,
impl<F, O> Unpin for Graph<F, O>where
F: Unpin,
O: Unpin,
impl<F, O> UnwindSafe for Graph<F, O>where
F: UnwindSafe + RefUnwindSafe,
O: UnwindSafe,
Blanket Implementations§
source§impl<T> Downcast for Twhere
T: Any,
impl<T> Downcast for Twhere
T: Any,
source§fn into_any(self: Box<T, Global>) -> Box<dyn Any + 'static, Global>
fn into_any(self: Box<T, Global>) -> Box<dyn Any + 'static, Global>
Converts `Box<dyn Trait>` (where `Trait: Downcast`) to `Box<dyn Any>`. `Box<dyn Any>` can
then be further downcast into `Box<ConcreteType>` where `ConcreteType` implements `Trait`.source§fn into_any_rc(self: Rc<T>) -> Rc<dyn Any + 'static>
fn into_any_rc(self: Rc<T>) -> Rc<dyn Any + 'static>
Converts `Rc<Trait>` (where `Trait: Downcast`) to `Rc<Any>`. `Rc<Any>` can then be
further downcast into `Rc<ConcreteType>` where `ConcreteType` implements `Trait`.source§fn as_any(&self) -> &(dyn Any + 'static)
fn as_any(&self) -> &(dyn Any + 'static)
&Trait (where Trait: Downcast) to &Any. This is needed since Rust cannot
generate &Any’s vtable from &Trait’s.source§fn as_any_mut(&mut self) -> &mut (dyn Any + 'static)
fn as_any_mut(&mut self) -> &mut (dyn Any + 'static)
&mut Trait (where Trait: Downcast) to &Any. This is needed since Rust cannot
generate &mut Any’s vtable from &mut Trait’s.