tract_libcli/
tensor.rs

1use std::collections::HashSet;
2use std::io::{Read, Seek};
3use std::ops::Range;
4use std::str::FromStr;
5use std::sync::Mutex;
6
7use crate::model::Model;
8use tract_hir::internal::*;
9use tract_num_traits::Zero;
10
/// Ordered collection of per-tensor overrides ([`TensorValues`]), looked up
/// either by tensor name or by model input index.
#[derive(Debug, Default, Clone)]
pub struct TensorsValues(pub Vec<TensorValues>);
13
14impl TensorsValues {
15    pub fn by_name(&self, name: &str) -> Option<&TensorValues> {
16        self.0.iter().find(|t| t.name.as_deref() == Some(name))
17    }
18    pub fn by_name_mut(&mut self, name: &str) -> Option<&mut TensorValues> {
19        self.0.iter_mut().find(|t| t.name.as_deref() == Some(name))
20    }
21    pub fn by_name_mut_with_default(&mut self, name: &str) -> &mut TensorValues {
22        if self.by_name_mut(name).is_none() {
23            self.add(TensorValues { name: Some(name.to_string()), ..TensorValues::default() });
24        }
25        self.by_name_mut(name).unwrap()
26    }
27
28    pub fn by_input_ix(&self, ix: usize) -> Option<&TensorValues> {
29        self.0.iter().find(|t| t.input_index == Some(ix))
30    }
31    pub fn by_input_ix_mut(&mut self, ix: usize) -> Option<&mut TensorValues> {
32        self.0.iter_mut().find(|t| t.input_index == Some(ix))
33    }
34    pub fn by_input_ix_mut_with_default(&mut self, ix: usize) -> &mut TensorValues {
35        if self.by_input_ix_mut(ix).is_none() {
36            self.add(TensorValues { input_index: Some(ix), ..TensorValues::default() });
37        }
38        self.by_input_ix_mut(ix).unwrap()
39    }
40
41    pub fn add(&mut self, other: TensorValues) {
42        let mut tensor = other.input_index.and_then(|ix| self.by_input_ix_mut(ix));
43
44        if tensor.is_none() {
45            tensor = other.name.as_deref().and_then(|ix| self.by_name_mut(ix))
46        }
47
48        if let Some(tensor) = tensor {
49            if tensor.fact.is_none() {
50                tensor.fact = other.fact;
51            }
52            if tensor.values.is_none() {
53                tensor.values = other.values;
54            }
55        } else {
56            self.0.push(other.clone());
57        };
58    }
59}
60
/// Overrides and constraints for a single model tensor, matched by name
/// and/or index (see `TensorsValues` lookups).
#[derive(Debug, PartialEq, Clone, Default)]
pub struct TensorValues {
    // Position among the model inputs, when this targets an input.
    pub input_index: Option<usize>,
    // Position among the model outputs, when this targets an output.
    pub output_index: Option<usize>,
    // Tensor (node) name, alternative lookup key.
    pub name: Option<String>,
    // Shape/type constraint for the tensor.
    pub fact: Option<InferenceFact>,
    // Fixed values, one entry per turn.
    pub values: Option<Vec<TValue>>,
    // Range used when generating random values for this tensor.
    pub random_range: Option<Range<f32>>,
}
70
71fn parse_dt(dt: &str) -> TractResult<DatumType> {
72    Ok(match dt.to_lowercase().as_ref() {
73        "bool" => DatumType::Bool,
74        "f16" => DatumType::F16,
75        "f32" => DatumType::F32,
76        "f64" => DatumType::F64,
77        "i8" => DatumType::I8,
78        "i16" => DatumType::I16,
79        "i32" => DatumType::I32,
80        "i64" => DatumType::I64,
81        "u8" => DatumType::U8,
82        "u16" => DatumType::U16,
83        "u32" => DatumType::U32,
84        "u64" => DatumType::U64,
85        "tdim" => DatumType::TDim,
86        _ => bail!(
87            "Type of the input should be f16, f32, f64, i8, i16, i16, i32, u8, u16, u32, u64, TDim."
88        ),
89    })
90}
91
92pub fn parse_spec(symbol_table: &SymbolScope, size: &str) -> TractResult<InferenceFact> {
93    if size.is_empty() {
94        return Ok(InferenceFact::default());
95    }
96    parse_coma_spec(symbol_table, size)
97}
98
99pub fn parse_coma_spec(symbol_table: &SymbolScope, size: &str) -> TractResult<InferenceFact> {
100    let splits = size.split(',').collect::<Vec<_>>();
101
102    #[allow(clippy::literal_string_with_formatting_args)]
103    if splits.is_empty() {
104        bail!("The <size> argument should be formatted as {{size}},{{...}},{{type}}.");
105    }
106
107    let last = splits.last().unwrap();
108    let (datum_type, shape) = if let Ok(dt) = parse_dt(last) {
109        (Some(dt), &splits[0..splits.len() - 1])
110    } else {
111        (None, &*splits)
112    };
113
114    let shape = ShapeFactoid::closed(
115        shape
116            .iter()
117            .map(|&s| {
118                Ok(if s == "_" {
119                    GenericFactoid::Any
120                } else {
121                    GenericFactoid::Only(parse_tdim(symbol_table, s)?)
122                })
123            })
124            .collect::<TractResult<TVec<DimFact>>>()?,
125    );
126
127    if let Some(dt) = datum_type {
128        Ok(InferenceFact::dt_shape(dt, shape))
129    } else {
130        Ok(InferenceFact::shape(shape))
131    }
132}
133
134fn parse_values<T: Datum + FromStr>(shape: &[usize], it: Vec<&str>) -> TractResult<Tensor> {
135    let values = it
136        .into_iter()
137        .map(|v| v.parse::<T>().map_err(|_| format_err!("Failed to parse {}", v)))
138        .collect::<TractResult<Vec<T>>>()?;
139    Ok(tract_ndarray::Array::from_shape_vec(shape, values)?.into())
140}
141
142fn tensor_for_text_data(
143    symbol_table: &SymbolScope,
144    _filename: &str,
145    mut reader: impl Read,
146) -> TractResult<Tensor> {
147    let mut data = String::new();
148    reader.read_to_string(&mut data)?;
149
150    let mut lines = data.lines();
151    let proto = parse_spec(symbol_table, lines.next().context("Empty data file")?)?;
152    let shape = proto.shape.concretize().unwrap();
153
154    let values = lines.flat_map(|l| l.split_whitespace()).collect::<Vec<&str>>();
155
156    // We know there is at most one streaming dimension, so we can deduce the
157    // missing value with a simple division.
158    let product: usize = shape.iter().map(|o| o.to_usize().unwrap_or(1)).product();
159    let missing = values.len() / product;
160
161    let shape: Vec<_> = shape.iter().map(|d| d.to_usize().unwrap_or(missing)).collect();
162    dispatch_numbers!(parse_values(proto.datum_type.concretize().unwrap())(&*shape, values))
163}
164
/// Parses the `data` command-line argument.
///
/// Returns an optional tensor name (only available from `.pb` protobuf
/// tensors) and an [`InferenceFact`] carrying the loaded value. Dispatch is
/// driven by `filename`: `.pb` → ONNX TensorProto, `"<file>.npz:<entry>"` →
/// NPZ archive entry, anything else → textual data file. Bytes are always
/// read from `reader`, never from `filename` directly.
pub fn for_data(
    symbol_table: &SymbolScope,
    filename: &str,
    reader: impl Read + std::io::Seek,
) -> TractResult<(Option<String>, InferenceFact)> {
    #[allow(unused_imports)]
    use std::convert::TryFrom;
    if filename.ends_with(".pb") {
        #[cfg(feature = "onnx")]
        {
            use tract_onnx::data_resolver::FopenDataResolver;
            use tract_onnx::tensor::load_tensor;
            let proto = ::tract_onnx::tensor::proto_from_reader(reader)?;
            let tensor = load_tensor(&FopenDataResolver, &proto, None)?;
            // An empty protobuf name is treated as "unnamed".
            Ok((Some(proto.name.to_string()).filter(|s| !s.is_empty()), tensor.into()))
        }
        #[cfg(not(feature = "onnx"))]
        {
            panic!("Loading tensor from protobuf requires onnx features");
        }
    } else if filename.contains(".npz:") {
        // "<archive>.npz:<entry>": the part after the colon names the entry.
        let mut tokens = filename.split(':');
        let (_filename, inner) = (tokens.next().unwrap(), tokens.next().unwrap());
        let mut npz = ndarray_npy::NpzReader::new(reader)?;
        Ok((None, for_npz(&mut npz, inner)?.into()))
    } else {
        Ok((None, tensor_for_text_data(symbol_table, filename, reader)?.into()))
    }
}
195
196pub fn for_npz(
197    npz: &mut ndarray_npy::NpzReader<impl Read + Seek>,
198    name: &str,
199) -> TractResult<Tensor> {
200    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<f32>, tract_ndarray::IxDyn>(name) {
201        return Ok(t.into_tensor());
202    }
203    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<f64>, tract_ndarray::IxDyn>(name) {
204        return Ok(t.into_tensor());
205    }
206    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i8>, tract_ndarray::IxDyn>(name) {
207        return Ok(t.into_tensor());
208    }
209    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i16>, tract_ndarray::IxDyn>(name) {
210        return Ok(t.into_tensor());
211    }
212    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i32>, tract_ndarray::IxDyn>(name) {
213        return Ok(t.into_tensor());
214    }
215    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i64>, tract_ndarray::IxDyn>(name) {
216        return Ok(t.into_tensor());
217    }
218    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u8>, tract_ndarray::IxDyn>(name) {
219        return Ok(t.into_tensor());
220    }
221    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u16>, tract_ndarray::IxDyn>(name) {
222        return Ok(t.into_tensor());
223    }
224    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u32>, tract_ndarray::IxDyn>(name) {
225        return Ok(t.into_tensor());
226    }
227    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u64>, tract_ndarray::IxDyn>(name) {
228        return Ok(t.into_tensor());
229    }
230    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<bool>, tract_ndarray::IxDyn>(name) {
231        return Ok(t.into_tensor());
232    }
233    bail!("Can not extract tensor from {}", name);
234}
235
236pub fn for_string(
237    symbol_table: &SymbolScope,
238    value: &str,
239) -> TractResult<(Option<String>, InferenceFact)> {
240    let (name, value) = if value.contains(':') {
241        let mut splits = value.split(':');
242        (Some(splits.next().unwrap().to_string()), splits.next().unwrap())
243    } else {
244        (None, value)
245    };
246    if value.contains('=') {
247        let mut split = value.split('=');
248        let spec = parse_spec(symbol_table, split.next().unwrap())?;
249        let value = split.next().unwrap().split(',');
250        let dt =
251            spec.datum_type.concretize().context("Must specify type when giving tensor value")?;
252        let shape = spec
253            .shape
254            .as_concrete_finite()?
255            .context("Must specify concrete shape when giving tensor value")?;
256        let tensor = if dt == TDim::datum_type() {
257            let mut tensor = Tensor::zero::<TDim>(&shape)?;
258            let values =
259                value.map(|v| parse_tdim(symbol_table, v)).collect::<TractResult<Vec<_>>>()?;
260            tensor.as_slice_mut::<TDim>()?.iter_mut().zip(values).for_each(|(t, v)| *t = v);
261            tensor
262        } else {
263            dispatch_numbers!(parse_values(dt)(&*shape, value.collect()))?
264        };
265        Ok((name, tensor.into()))
266    } else {
267        Ok((name, parse_spec(symbol_table, value)?))
268    }
269}
270
lazy_static::lazy_static! {
    // Process-wide registry of messages already emitted by `info_once`.
    static ref MESSAGE_ONCE: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
}
274
275fn info_once(msg: String) {
276    if MESSAGE_ONCE.lock().unwrap().insert(msg.clone()) {
277        info!("{msg}");
278    }
279}
280
/// Everything needed to resolve the input tensors for a model run.
pub struct RunParams {
    // Per-tensor overrides: fixed values, facts, random ranges.
    pub tensors_values: TensorsValues,
    // Generate random tensors for inputs with no provided value.
    pub allow_random_input: bool,
    // Allow casting provided f32 inputs when the model expects f16.
    pub allow_float_casts: bool,
    // Concrete values for the model's symbolic dimensions.
    pub symbols: SymbolValues,
}
287
/// Fully-resolved tensors for a run.
pub struct RunTensors {
    // One set of input values per turn (outer index: turn, inner: model input).
    pub sources: Vec<TVec<TValue>>,
    // Initial values for stateful ops (KV cache, etc.), one per stateful node.
    pub state_initializers: Vec<TValue>,
}
292
/// Resolves run-time value(s) for one tensor called `name` with typed fact
/// `fact`, pushing one `Vec<TValue>` (one entry per turn) onto `target`.
///
/// Resolution order:
/// 1. user-provided values (matched by name, then by input index), possibly
///    requantized, cast to f16, or sliced into pulses to fit the model;
/// 2. a zero tensor when the fact has a concrete zero-volume shape;
/// 3. a random tensor when `allow_random_input` is set;
/// 4. otherwise, an error.
fn get_or_make_tensors(
    model: &dyn Model,
    params: &RunParams,
    fact: TypedFact,
    name: &str,
    input_idx: usize,
    target: &mut TVec<Vec<TValue>>,
) -> TractResult<()> {
    if let Some(mut value) = params
        .tensors_values
        .by_name(name)
        .or_else(|| params.tensors_values.by_input_ix(input_idx))
        .and_then(|t| t.values.clone())
    {
        // Values given in the unquantized flavour of a quantized input:
        // reinterpret them in place as the quantized type (same bit width).
        if !value[0].datum_type().is_quantized()
            && fact.datum_type.is_quantized()
            && value[0].datum_type() == fact.datum_type.unquantized()
        {
            value = value
                .iter()
                .map(|v| {
                    let mut v = v.clone().into_tensor();
                    unsafe { v.set_datum_type(fact.datum_type) };
                    v.into()
                })
                .collect();
        }
        if TypedFact::shape_and_dt_of(&value[0]).compatible_with(&fact) {
            // Direct match: use the provided values as-is.
            info!("Using fixed input for input called {} ({} turn(s))", name, value.len());
            target.push(value.iter().map(|t| t.clone().into_tensor().into()).collect());
        } else if fact.datum_type == f16::datum_type()
            && value[0].datum_type() == f32::datum_type()
            && params.allow_float_casts
        {
            info!("Casting input to F16 for input called {} ({} turn(s))", name, value.len());
            target.push(
                value.iter().map(|t| t.cast_to::<f16>().unwrap().into_owned().into()).collect(),
            );
        } else if value.len() == 1 && model.properties().contains_key("pulse.delay") {
            // Pulsed model: slice the single provided tensor along its pulse
            // axis into as many input pulses as needed to flush the output.
            let value = &value[0];
            let input_pulse_axis = model
                .properties()
                .get("pulse.input_axes")
                .context("Expect pulse.input_axes property")?
                .cast_to::<i64>()?
                .as_slice::<i64>()?[input_idx] as usize;
            let input_pulse = fact.shape.get(input_pulse_axis).unwrap().to_usize().unwrap();
            let input_len = value.shape()[input_pulse_axis];

            // how many pulses do we need to push full result out ?
            // guess by looking at len and delay of the first output
            let output_pulse_axis = model
                .properties()
                .get("pulse.output_axes")
                .context("Expect pulse.output_axes property")?
                .cast_to::<i64>()?
                .as_slice::<i64>()?[0] as usize;
            let output_fact = model.outlet_typedfact(model.output_outlets()[0])?;
            let output_pulse =
                output_fact.shape.get(output_pulse_axis).unwrap().to_usize().unwrap();
            let output_len = input_len * output_pulse / input_pulse;
            let output_delay = model.properties()["pulse.delay"].as_slice::<i64>()?[0] as usize;
            let last_frame = output_len + output_delay;
            let needed_pulses = last_frame.divceil(output_pulse);
            let mut values = vec![];
            for ix in 0..needed_pulses {
                // Start from zeros so the trailing (possibly partial) pulse is
                // implicitly zero-padded.
                let mut t = Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
                let start = ix * input_pulse;
                let end = (start + input_pulse).min(input_len);
                if end > start {
                    t.assign_slice(0..end - start, value, start..end, input_pulse_axis)?;
                }
                values.push(t.into());
            }
            info!(
                "Generated {} pulses of shape {:?} for input {}.",
                needed_pulses, fact.shape, input_idx
            );
            target.push(values);
        } else {
            bail!(
                "For input {}, can not reconcile model input fact {:?} with provided input {:?}",
                name,
                fact,
                value[0]
            );
        };
    } else if fact.shape.is_concrete() && fact.shape.volume() == TDim::zero() {
        // Zero-volume tensor: the all-zeros tensor is the only possible value.
        let shape = fact.shape.as_concrete().unwrap();
        let tensor = Tensor::zero_dt(fact.datum_type, shape)?;
        target.push(vec![tensor.into()]);
    } else if params.allow_random_input {
        info_once(format!("Using random input for input called {name:?}: {fact:?}"));
        let tv = params
            .tensors_values
            .by_name(name)
            .or_else(|| params.tensors_values.by_input_ix(input_idx));
        // Evaluate symbolic dimensions against the user-provided symbol values
        // before generating the random tensor.
        let mut fact = fact.clone();
        fact.shape = fact.shape.iter().map(|dim| dim.eval(&params.symbols)).collect();
        target.push(vec![tensor_for_fact(&fact, None, tv)?.into()]);
    } else {
        bail!(
            "Unmatched tensor {}. Fix the input or use \"--allow-random-input\" if this was intended",
            name
        );
    }
    Ok(())
}
401
/// Resolves every tensor needed for a run of `tract`: per-turn source inputs
/// plus initial values for stateful ops (KV cache and the like).
pub fn get_or_make_inputs(tract: &dyn Model, params: &RunParams) -> TractResult<RunTensors> {
    // Resolve source inputs
    let mut tmp_inputs = tvec![];
    for (ix, input) in tract.input_outlets().iter().enumerate() {
        let fact = tract.outlet_typedfact(*input)?;
        let name = tract.node_name(input.node);
        get_or_make_tensors(tract, params, fact, name, ix, &mut tmp_inputs)?;
    }

    // Transpose from per-input turn lists to per-turn input lists.
    // NOTE(review): assumes every input yields the same number of turns as the
    // first one; `t[i]` below would panic otherwise — confirm invariant.
    let n_turns = tmp_inputs.first().map_or(0, |t| t.len());
    let sources = (0..n_turns)
        .map(|i| tmp_inputs.iter().map(|t| t[i].clone()).collect::<TVec<_>>())
        .collect::<Vec<_>>();

    // Resolve state initializers (KV Cache, etc.)
    let mut dummy_session_state = SessionState::default();
    let state_initializers = (0..tract.nodes_len())
        .filter_map(|id| {
            tract
                .node_op(id)
                .state(&mut dummy_session_state, id)
                .ok()
                .flatten()
                .and_then(|state| state.init_tensor_fact())
                .map(|fact| {
                    // usize::MAX: state tensors are matched by name only,
                    // never by input index.
                    let mut tmp = tvec![];
                    get_or_make_tensors(
                        tract,
                        params,
                        fact,
                        tract.node_name(id),
                        usize::MAX,
                        &mut tmp,
                    )?;
                    Ok(tmp.remove(0).remove(0))
                })
        })
        .collect::<TractResult<Vec<_>>>()?;

    Ok(RunTensors { sources, state_initializers })
}
443
444fn make_inputs(values: &[impl std::borrow::Borrow<TypedFact>]) -> TractResult<TVec<TValue>> {
445    values.iter().map(|v| tensor_for_fact(v.borrow(), None, None).map(|t| t.into())).collect()
446}
447
448pub fn make_inputs_for_model(model: &dyn Model) -> TractResult<TVec<TValue>> {
449    make_inputs(
450        &model
451            .input_outlets()
452            .iter()
453            .map(|&t| model.outlet_typedfact(t))
454            .collect::<TractResult<Vec<TypedFact>>>()?,
455    )
456}
457
458#[allow(unused_variables)]
459pub fn tensor_for_fact(
460    fact: &TypedFact,
461    streaming_dim: Option<usize>,
462    tv: Option<&TensorValues>,
463) -> TractResult<Tensor> {
464    if let Some(value) = &fact.konst {
465        return Ok(value.clone().into_tensor());
466    }
467    Ok(random(
468        fact.shape
469            .as_concrete()
470            .with_context(|| format!("Expected concrete shape, found: {fact:?}"))?,
471        fact.datum_type,
472        tv,
473    ))
474}
475
476/// Generates a random tensor of a given size and type.
477pub fn random(sizes: &[usize], datum_type: DatumType, tv: Option<&TensorValues>) -> Tensor {
478    use rand::{Rng, SeedableRng};
479    let mut rng = rand::rngs::StdRng::seed_from_u64(21242);
480    let mut tensor = Tensor::zero::<f32>(sizes).unwrap();
481    let slice = tensor.as_slice_mut::<f32>().unwrap();
482    if let Some(range) = tv.and_then(|tv| tv.random_range.as_ref()) {
483        slice.iter_mut().for_each(|x| *x = rng.gen_range(range.clone()))
484    } else {
485        slice.iter_mut().for_each(|x| *x = rng.r#gen())
486    };
487    tensor.cast_to_dt(datum_type).unwrap().into_owned()
488}