Function tract_libcli::tensor::tensor_for_fact
source · pub fn tensor_for_fact(
fact: &TypedFact,
streaming_dim: Option<usize>,
tv: Option<&TensorValues>
) -> TractResult<Tensor>

Examples found in repository:
src/tensor.rs (line 366)
Listing of src/tensor.rs, lines 281–376:
pub fn retrieve_or_make_inputs(
tract: &dyn Model,
params: &RunParams,
) -> TractResult<Vec<TVec<TValue>>> {
let mut tmp: TVec<Vec<TValue>> = tvec![];
for (ix, input) in tract.input_outlets().iter().enumerate() {
let name = tract.node_name(input.node);
let fact = tract.outlet_typedfact(*input)?;
if let Some(mut value) = params.tensors_values.by_name(name).and_then(|t| t.values.clone())
{
if !value[0].datum_type().is_quantized()
&& fact.datum_type.is_quantized()
&& value[0].datum_type() == fact.datum_type.unquantized()
{
value = value
.iter()
.map(|v| {
let mut v = v.clone().into_tensor();
unsafe { v.set_datum_type(fact.datum_type) };
v.into()
})
.collect();
}
if TypedFact::from(&*value[0]).compatible_with(&fact) {
info!("Using fixed input for input called {} ({} turn(s))", name, value.len());
tmp.push(value.iter().map(|t| t.clone().into_tensor().into()).collect())
} else if fact.datum_type == f16::datum_type()
&& value[0].datum_type() == f32::datum_type()
&& params.allow_float_casts
{
tmp.push(
value.iter().map(|t| t.cast_to::<f16>().unwrap().into_owned().into()).collect(),
)
} else if value.len() == 1 && tract.properties().contains_key("pulse.delay") {
let value = &value[0];
let input_pulse_axis = tract
.properties()
.get("pulse.input_axes")
.context("Expect pulse.input_axes property")?
.cast_to::<i64>()?
.as_slice::<i64>()?[ix] as usize;
let input_pulse = fact.shape.get(input_pulse_axis).unwrap().to_usize().unwrap();
let input_len = value.shape()[input_pulse_axis];
// how many pulses do we need to push full result out ?
// guess by looking at len and delay of the first output
let output_pulse_axis = tract
.properties()
.get("pulse.output_axes")
.context("Expect pulse.output_axes property")?
.cast_to::<i64>()?
.as_slice::<i64>()?[0] as usize;
let output_fact = tract.outlet_typedfact(tract.output_outlets()[0])?;
let output_pulse =
output_fact.shape.get(output_pulse_axis).unwrap().to_usize().unwrap();
let output_len = input_len * output_pulse / input_pulse;
let output_delay = tract.properties()["pulse.delay"].as_slice::<i64>()?[0] as usize;
let last_frame = output_len + output_delay;
let needed_pulses = last_frame.divceil(output_pulse);
let mut values = vec![];
for ix in 0..needed_pulses {
let mut t =
Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
let start = ix * input_pulse;
let end = (start + input_pulse).min(input_len);
if end > start {
t.assign_slice(0..end - start, value, start..end, input_pulse_axis)?;
}
values.push(t.into());
}
info!(
"Generated {} pulses of shape {:?} for input {}.",
needed_pulses, fact.shape, ix
);
tmp.push(values);
} else {
bail!("For input {}, can not reconcile model input fact {:?} with provided input {:?}", name, fact, value[0]);
};
} else if params.allow_random_input {
let fact = tract.outlet_typedfact(*input)?;
warn_once(format!("Using random input for input called {:?}: {:?}", name, fact));
let tv = params
.tensors_values
.by_name(name)
.or_else(|| params.tensors_values.by_input_ix(ix));
tmp.push(vec![crate::tensor::tensor_for_fact(&fact, None, tv)?.into()]);
} else {
bail!("Unmatched tensor {}. Fix the input or use \"--allow-random-input\" if this was intended", name);
}
}
Ok((0..tmp[0].len()).map(|turn| tmp.iter().map(|t| t[turn].clone()).collect()).collect())
}
fn make_inputs(values: &[impl std::borrow::Borrow<TypedFact>]) -> TractResult<TVec<TValue>> {
values.iter().map(|v| tensor_for_fact(v.borrow(), None, None).map(|t| t.into())).collect()
}