Skip to main content

tract_libcli/
tensor.rs

1use std::collections::HashSet;
2use std::io::{Read, Seek};
3use std::ops::Range;
4use std::str::FromStr;
5use std::sync::Mutex;
6
7use crate::model::Model;
8use tract_hir::internal::*;
9use tract_num_traits::Zero;
10
11#[cfg(feature = "transformers")]
12use tract_transformers::figure_out_causal_llm_b_s_p;
13
14#[derive(Debug, Default, Clone)]
15pub struct TensorsValues(pub Vec<TensorValues>);
16
17impl TensorsValues {
18    pub fn by_name(&self, name: &str) -> Option<&TensorValues> {
19        self.0.iter().find(|t| t.name.as_deref() == Some(name))
20    }
21    pub fn by_name_mut(&mut self, name: &str) -> Option<&mut TensorValues> {
22        self.0.iter_mut().find(|t| t.name.as_deref() == Some(name))
23    }
24    pub fn by_name_mut_with_default(&mut self, name: &str) -> &mut TensorValues {
25        if self.by_name_mut(name).is_none() {
26            self.add(TensorValues { name: Some(name.to_string()), ..TensorValues::default() });
27        }
28        self.by_name_mut(name).unwrap()
29    }
30
31    pub fn by_input_ix(&self, ix: usize) -> Option<&TensorValues> {
32        self.0.iter().find(|t| t.input_index == Some(ix))
33    }
34    pub fn by_input_ix_mut(&mut self, ix: usize) -> Option<&mut TensorValues> {
35        self.0.iter_mut().find(|t| t.input_index == Some(ix))
36    }
37    pub fn by_input_ix_mut_with_default(&mut self, ix: usize) -> &mut TensorValues {
38        if self.by_input_ix_mut(ix).is_none() {
39            self.add(TensorValues { input_index: Some(ix), ..TensorValues::default() });
40        }
41        self.by_input_ix_mut(ix).unwrap()
42    }
43
44    pub fn add(&mut self, other: TensorValues) {
45        let mut tensor = other.input_index.and_then(|ix| self.by_input_ix_mut(ix));
46
47        if tensor.is_none() {
48            tensor = other.name.as_deref().and_then(|ix| self.by_name_mut(ix))
49        }
50
51        if let Some(tensor) = tensor {
52            if tensor.fact.is_none() {
53                tensor.fact = other.fact;
54            }
55            if tensor.values.is_none() {
56                tensor.values = other.values;
57            }
58        } else {
59            self.0.push(other.clone());
60        };
61    }
62
63    pub fn input_by_name(&self, name: &str) -> Option<&TensorValues> {
64        self.0
65            .iter()
66            .filter(|tv| tv.output_index.is_none() && !tv.only_output)
67            .find(|t| t.name.as_deref() == Some(name))
68    }
69}
70
71#[derive(Debug, PartialEq, Clone, Default)]
72pub struct TensorValues {
73    pub input_index: Option<usize>,
74    pub output_index: Option<usize>,
75    pub name: Option<String>,
76    pub fact: Option<InferenceFact>,
77    pub values: Option<Vec<TValue>>,
78    pub random_range: Option<Range<f32>>,
79    pub only_input: bool,
80    pub only_output: bool,
81}
82
83fn parse_dt(dt: &str) -> TractResult<DatumType> {
84    Ok(match dt.to_lowercase().as_ref() {
85        "bool" => DatumType::Bool,
86        "f16" => DatumType::F16,
87        "f32" => DatumType::F32,
88        "f64" => DatumType::F64,
89        "i8" => DatumType::I8,
90        "i16" => DatumType::I16,
91        "i32" => DatumType::I32,
92        "i64" => DatumType::I64,
93        "u8" => DatumType::U8,
94        "u16" => DatumType::U16,
95        "u32" => DatumType::U32,
96        "u64" => DatumType::U64,
97        "tdim" => DatumType::TDim,
98        _ => bail!(
99            "Type of the input should be f16, f32, f64, i8, i16, i16, i32, u8, u16, u32, u64, TDim."
100        ),
101    })
102}
103
104pub fn parse_spec(symbol_table: &SymbolScope, size: &str) -> TractResult<InferenceFact> {
105    if size.is_empty() {
106        return Ok(InferenceFact::default());
107    }
108    parse_coma_spec(symbol_table, size)
109}
110
111pub fn parse_coma_spec(symbol_table: &SymbolScope, size: &str) -> TractResult<InferenceFact> {
112    let splits = size.split(',').collect::<Vec<_>>();
113
114    #[allow(clippy::literal_string_with_formatting_args)]
115    if splits.is_empty() {
116        bail!("The <size> argument should be formatted as {{size}},{{...}},{{type}}.");
117    }
118
119    let last = splits.last().unwrap();
120    let (datum_type, shape) = if let Ok(dt) = parse_dt(last) {
121        (Some(dt), &splits[0..splits.len() - 1])
122    } else {
123        (None, &*splits)
124    };
125
126    let shape = ShapeFactoid::closed(
127        shape
128            .iter()
129            .map(|&s| {
130                Ok(if s == "_" {
131                    GenericFactoid::Any
132                } else {
133                    GenericFactoid::Only(parse_tdim(symbol_table, s)?)
134                })
135            })
136            .collect::<TractResult<TVec<DimFact>>>()?,
137    );
138
139    if let Some(dt) = datum_type {
140        Ok(InferenceFact::dt_shape(dt, shape))
141    } else {
142        Ok(InferenceFact::shape(shape))
143    }
144}
145
146fn parse_values<T: Datum + FromStr>(shape: &[usize], it: Vec<&str>) -> TractResult<Tensor> {
147    let values = it
148        .into_iter()
149        .map(|v| v.parse::<T>().map_err(|_| format_err!("Failed to parse {}", v)))
150        .collect::<TractResult<Vec<T>>>()?;
151    Ok(tract_ndarray::Array::from_shape_vec(shape, values)?.into())
152}
153
154fn tensor_for_text_data(
155    symbol_table: &SymbolScope,
156    _filename: &str,
157    mut reader: impl Read,
158) -> TractResult<Tensor> {
159    let mut data = String::new();
160    reader.read_to_string(&mut data)?;
161
162    let mut lines = data.lines();
163    let proto = parse_spec(symbol_table, lines.next().context("Empty data file")?)?;
164    let shape = proto.shape.concretize().unwrap();
165
166    let values = lines.flat_map(|l| l.split_whitespace()).collect::<Vec<&str>>();
167
168    // We know there is at most one streaming dimension, so we can deduce the
169    // missing value with a simple division.
170    let product: usize = shape.iter().map(|o| o.to_usize().unwrap_or(1)).product();
171    let missing = values.len() / product;
172
173    let shape: Vec<_> = shape.iter().map(|d| d.to_usize().unwrap_or(missing)).collect();
174    dispatch_numbers!(parse_values(proto.datum_type.concretize().unwrap())(&*shape, values))
175}
176
177/// Parses the `data` command-line argument.
178pub fn for_data(
179    symbol_table: &SymbolScope,
180    filename: &str,
181    reader: impl Read + std::io::Seek,
182) -> TractResult<(Option<String>, InferenceFact)> {
183    #[allow(unused_imports)]
184    use std::convert::TryFrom;
185    if filename.ends_with(".pb") {
186        #[cfg(feature = "onnx")]
187        {
188            use tract_onnx::data_resolver::FopenDataResolver;
189            use tract_onnx::tensor::load_tensor;
190            let proto = ::tract_onnx::tensor::proto_from_reader(reader)?;
191            let tensor = load_tensor(&FopenDataResolver, &proto, None)?;
192            Ok((Some(proto.name.to_string()).filter(|s| !s.is_empty()), tensor.into()))
193        }
194        #[cfg(not(feature = "onnx"))]
195        {
196            panic!("Loading tensor from protobuf requires onnx features");
197        }
198    } else if filename.contains(".npz:") {
199        let mut tokens = filename.split(':');
200        let (_filename, inner) = (tokens.next().unwrap(), tokens.next().unwrap());
201        let mut npz = ndarray_npy::NpzReader::new(reader)?;
202        Ok((None, for_npz(&mut npz, inner)?.into()))
203    } else {
204        Ok((None, tensor_for_text_data(symbol_table, filename, reader)?.into()))
205    }
206}
207
208pub fn for_npz(
209    npz: &mut ndarray_npy::NpzReader<impl Read + Seek>,
210    name: &str,
211) -> TractResult<Tensor> {
212    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<f32>, tract_ndarray::IxDyn>(name) {
213        return Ok(t.into_tensor());
214    }
215    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<f64>, tract_ndarray::IxDyn>(name) {
216        return Ok(t.into_tensor());
217    }
218    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i8>, tract_ndarray::IxDyn>(name) {
219        return Ok(t.into_tensor());
220    }
221    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i16>, tract_ndarray::IxDyn>(name) {
222        return Ok(t.into_tensor());
223    }
224    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i32>, tract_ndarray::IxDyn>(name) {
225        return Ok(t.into_tensor());
226    }
227    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<i64>, tract_ndarray::IxDyn>(name) {
228        return Ok(t.into_tensor());
229    }
230    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u8>, tract_ndarray::IxDyn>(name) {
231        return Ok(t.into_tensor());
232    }
233    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u16>, tract_ndarray::IxDyn>(name) {
234        return Ok(t.into_tensor());
235    }
236    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u32>, tract_ndarray::IxDyn>(name) {
237        return Ok(t.into_tensor());
238    }
239    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<u64>, tract_ndarray::IxDyn>(name) {
240        return Ok(t.into_tensor());
241    }
242    if let Ok(t) = npz.by_name::<tract_ndarray::OwnedRepr<bool>, tract_ndarray::IxDyn>(name) {
243        return Ok(t.into_tensor());
244    }
245    bail!("Can not extract tensor from {}", name);
246}
247
248pub fn for_string(
249    symbol_table: &SymbolScope,
250    value: &str,
251) -> TractResult<(Option<String>, InferenceFact)> {
252    let (name, value) = if value.contains(':') {
253        let mut splits = value.split(':');
254        (Some(splits.next().unwrap().to_string()), splits.next().unwrap())
255    } else {
256        (None, value)
257    };
258    if value.contains('=') {
259        let mut split = value.split('=');
260        let spec = parse_spec(symbol_table, split.next().unwrap())?;
261        let value = split.next().unwrap().split(',');
262        let dt =
263            spec.datum_type.concretize().context("Must specify type when giving tensor value")?;
264        let shape = spec
265            .shape
266            .as_concrete_finite()?
267            .context("Must specify concrete shape when giving tensor value")?;
268        let tensor = if dt == TDim::datum_type() {
269            let mut tensor = Tensor::zero::<TDim>(&shape)?;
270            let values =
271                value.map(|v| parse_tdim(symbol_table, v)).collect::<TractResult<Vec<_>>>()?;
272            tensor
273                .try_as_plain_mut()?
274                .as_slice_mut::<TDim>()?
275                .iter_mut()
276                .zip(values)
277                .for_each(|(t, v)| *t = v);
278            tensor
279        } else {
280            dispatch_numbers!(parse_values(dt)(&*shape, value.collect()))?
281        };
282        Ok((name, tensor.into()))
283    } else {
284        Ok((name, parse_spec(symbol_table, value)?))
285    }
286}
287
288lazy_static::lazy_static! {
289    static ref MESSAGE_ONCE: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
290}
291
292fn info_once(msg: String) {
293    if MESSAGE_ONCE.lock().unwrap().insert(msg.clone()) {
294        info!("{msg}");
295    }
296}
297
298pub struct RunParams {
299    pub tensors_values: TensorsValues,
300    pub allow_random_input: bool,
301    pub allow_float_casts: bool,
302    pub symbols: SymbolValues,
303    pub prompt_chunk_size: Option<usize>,
304    pub drop_partial_pulse: bool,
305}
306
307pub struct RunTensors {
308    pub sources: Vec<TVec<TValue>>,
309    /// In pulse mode, the *real* input length on the streaming axis
310    /// (i.e. before the trailing turns of zero-padding that
311    /// `get_or_make_tensors` adds to flush the pipeline).  Used by the
312    /// runner to bind the streaming symbol so PulsePad and friends
313    /// resolve `end_input` correctly at end-of-stream.  `None` for
314    /// non-pulse runs.
315    pub streaming_input_len: Option<usize>,
316}
317
318#[cfg(feature = "transformers")]
319fn chunk_fact(
320    fact: &TypedFact,
321    params: &RunParams,
322    model: &Arc<dyn Model>,
323) -> TractResult<Vec<TypedFact>> {
324    let Some(chunk_size) = params.prompt_chunk_size else {
325        return Ok(vec![fact.clone()]);
326    };
327    let Some(model) = model.downcast_ref::<TypedModel>() else {
328        return Ok(vec![fact.clone()]);
329    };
330    let (_, s, _) = figure_out_causal_llm_b_s_p(model)?;
331    let Some(s) = s else {
332        return Ok(vec![fact.clone()]);
333    };
334
335    let dims = fact.shape.dims();
336    let Some(sym_idx) = dims.iter().position(|d| *d == TDim::Sym(s.clone())) else {
337        return Ok(vec![fact.clone()]);
338    };
339
340    let resolved_sym = dims[sym_idx].eval_to_i64(&params.symbols)? as usize;
341    if resolved_sym <= chunk_size {
342        return Ok(vec![fact.clone()]);
343    }
344
345    let num_chunks = resolved_sym.div_ceil(chunk_size);
346    let mut out = Vec::with_capacity(num_chunks);
347
348    for start in (0..resolved_sym).step_by(chunk_size) {
349        let this = chunk_size.min(resolved_sym - start) as i64;
350
351        let mut new_fact = fact.clone();
352        new_fact.shape = new_fact
353            .shape
354            .iter()
355            .enumerate()
356            .map(|(i, d)| if i == sym_idx { TDim::Val(this) } else { d.eval(&params.symbols) })
357            .collect();
358
359        out.push(new_fact);
360    }
361
362    Ok(out)
363}
364
365#[cfg(feature = "transformers")]
366fn chunk_tensor(
367    tensor: Tensor,
368    fact: &TypedFact,
369    params: &RunParams,
370    model: &Arc<dyn Model>,
371) -> TractResult<Vec<TValue>> {
372    let Some(chunk_size) = params.prompt_chunk_size else {
373        return Ok(vec![tensor.into_tvalue()]);
374    };
375
376    let Some(model) = model.downcast_ref::<TypedModel>() else {
377        return Ok(vec![tensor.into_tvalue()]);
378    };
379    let (_, s, _) = figure_out_causal_llm_b_s_p(model)?;
380    let Some(s) = s else {
381        return Ok(vec![tensor.into_tvalue()]);
382    };
383
384    let dims = fact.shape.dims();
385    let Some(symb_axis) = dims.iter().position(|d| *d == TDim::Sym(s.clone())) else {
386        return Ok(vec![tensor.into_tvalue()]);
387    };
388
389    let resolved_sym = tensor.shape()[symb_axis];
390    if resolved_sym <= chunk_size {
391        return Ok(vec![tensor.into_tvalue()]);
392    }
393
394    let num_chunks = resolved_sym.div_ceil(chunk_size);
395    let mut out = Vec::with_capacity(num_chunks);
396
397    for start in (0..resolved_sym).step_by(chunk_size) {
398        let this = chunk_size.min(resolved_sym - start);
399        out.push(tensor.slice(symb_axis, start, start + this)?.into_tvalue());
400    }
401
402    Ok(out)
403}
404
405fn get_or_make_tensors(
406    model: &Arc<dyn Model>,
407    params: &RunParams,
408    fact: TypedFact,
409    name: &str,
410    input_idx: usize,
411    target: &mut TVec<Vec<TValue>>,
412    streaming_input_len: &mut Option<usize>,
413) -> TractResult<()> {
414    if let Some(mut value) = params
415        .tensors_values
416        .by_name(name)
417        .or_else(|| params.tensors_values.by_input_ix(input_idx))
418        .and_then(|t| t.values.clone())
419    {
420        if !value[0].datum_type().is_quantized()
421            && fact.datum_type.is_quantized()
422            && value[0].datum_type() == fact.datum_type.unquantized()
423        {
424            value = value
425                .iter()
426                .map(|v| {
427                    let mut v = v.clone().into_tensor();
428                    unsafe { v.set_datum_type(fact.datum_type) };
429                    v.into()
430                })
431                .collect();
432        }
433        let mut chunked_tensors: Vec<TValue> = vec![];
434        for t in &value {
435            let tensor = if TypedFact::shape_and_dt_of(&value[0]).compatible_with(&fact) {
436                info_once(format!(
437                    "Using fixed input for input called {} ({} turn(s))",
438                    name,
439                    value.len()
440                ));
441                t.clone().into_tensor()
442            } else if fact.datum_type == f16::datum_type()
443                && value[0].datum_type() == f32::datum_type()
444                && params.allow_float_casts
445            {
446                debug!("Casting input to F16 for input called {} ({} turn(s))", name, value.len());
447                t.cast_to::<f16>()?.into_owned()
448            } else {
449                break;
450            };
451
452            chunked_tensors.extend(chunk_tensor(tensor, &fact, params, model)?);
453        }
454        if !chunked_tensors.is_empty() {
455            target.push(chunked_tensors);
456            return Ok(());
457        }
458
459        if value.len() == 1 && model.properties().contains_key("pulse.delay") {
460            let value = &value[0];
461            let input_pulse_axis = model
462                .properties()
463                .get("pulse.input_axes")
464                .context("Expect pulse.input_axes property")?
465                .cast_to::<i64>()?
466                .try_as_plain()?
467                .as_slice::<i64>()?[input_idx] as usize;
468            let input_pulse = fact.shape.get(input_pulse_axis).unwrap().to_usize().unwrap();
469            let mut input_len = value.shape()[input_pulse_axis];
470            if params.drop_partial_pulse && input_len % input_pulse != 0 {
471                input_len = (input_len / input_pulse) * input_pulse;
472                info!(
473                    "Dropping partial trailing pulse: truncating input from {} to {} on axis {}.",
474                    value.shape()[input_pulse_axis],
475                    input_len,
476                    input_pulse_axis
477                );
478            }
479            // Record the real streaming-axis length on the *first* input we
480            // see; downstream uses it to bind the streaming symbol so that
481            // PulsePad's `end_input` resolves to the correct end-of-stream.
482            // (All inputs share the same streaming dim, so picking the first
483            // is fine.)
484            if streaming_input_len.is_none() {
485                *streaming_input_len = Some(input_len);
486            }
487
488            // how many pulses do we need to push full result out ?
489            // guess by looking at len and delay of the first output
490            let output_pulse_axis = model
491                .properties()
492                .get("pulse.output_axes")
493                .context("Expect pulse.output_axes property")?
494                .cast_to::<i64>()?
495                .try_as_plain()?
496                .as_slice::<i64>()?[0] as usize;
497            let output_fact = model.outlet_typedfact(model.output_outlets()[0])?;
498            let output_pulse =
499                output_fact.shape.get(output_pulse_axis).unwrap().to_usize().unwrap();
500            let output_len = input_len * output_pulse / input_pulse;
501            let output_delay =
502                model.properties()["pulse.delay"].try_as_plain()?.as_slice::<i64>()?[0] as usize;
503            let last_frame = output_len + output_delay;
504            let needed_pulses = last_frame.divceil(output_pulse);
505            let mut values = vec![];
506            for ix in 0..needed_pulses {
507                let mut t = Tensor::zero_dt(fact.datum_type, fact.shape.as_concrete().unwrap())?;
508                let start = ix * input_pulse;
509                let end = (start + input_pulse).min(input_len);
510                if end > start {
511                    t.assign_slice(0..end - start, value, start..end, input_pulse_axis)?;
512                }
513                values.push(t.into());
514            }
515            info!(
516                "Generated {} pulses of shape {:?} for input {}.",
517                needed_pulses, fact.shape, input_idx
518            );
519            target.push(values);
520        } else {
521            bail!(
522                "For input {}, can not reconcile model input fact {:?} with provided input {:?}",
523                name,
524                fact,
525                value[0]
526            );
527        };
528    } else if fact.shape.is_concrete() && fact.shape.volume() == TDim::zero() {
529        let shape = fact.shape.as_concrete().unwrap();
530        let tensor = Tensor::zero_dt(fact.datum_type, shape)?;
531        target.push(vec![tensor.into()]);
532    } else if params.allow_random_input {
533        info_once(format!("Using random input for input called {name:?}: {fact:?}"));
534        let tv = params
535            .tensors_values
536            .by_name(name)
537            .or_else(|| params.tensors_values.by_input_ix(input_idx));
538
539        let mut chunked_facts = chunk_fact(&fact, params, model)?;
540
541        let mut chunked_tensors = Vec::with_capacity(chunked_facts.len());
542        for fact in &mut chunked_facts {
543            fact.shape = fact.shape.iter().map(|dim| dim.eval(&params.symbols)).collect();
544            chunked_tensors.push(tensor_for_fact(fact, None, tv)?.into());
545        }
546        target.push(chunked_tensors);
547    } else {
548        bail!(
549            "Unmatched tensor {}. Fix the input or use \"--allow-random-input\" if this was intended",
550            name
551        );
552    }
553    Ok(())
554}
555
556pub fn get_or_make_inputs(tract: &Arc<dyn Model>, params: &RunParams) -> TractResult<RunTensors> {
557    // Resolve source inputs
558    let mut tmp_inputs = tvec![];
559    let mut streaming_input_len = None;
560    for (ix, input) in tract.input_outlets().iter().enumerate() {
561        let fact = tract.outlet_typedfact(*input)?;
562        let name = tract.node_name(input.node);
563        get_or_make_tensors(
564            tract,
565            params,
566            fact,
567            name,
568            ix,
569            &mut tmp_inputs,
570            &mut streaming_input_len,
571        )?;
572    }
573
574    let n_turns = tmp_inputs.iter().map(|t| t.len()).max().unwrap_or(0);
575    let sources = (0..n_turns)
576        .map(|i| {
577            tmp_inputs
578                .iter()
579                .map(|t| if i < t.len() { t[i].clone() } else { t[t.len() - 1].clone() })
580                .collect::<TVec<_>>()
581        })
582        .collect::<Vec<_>>();
583
584    Ok(RunTensors { sources, streaming_input_len })
585}
586
587fn make_inputs(values: &[impl std::borrow::Borrow<TypedFact>]) -> TractResult<TVec<TValue>> {
588    values.iter().map(|v| tensor_for_fact(v.borrow(), None, None).map(|t| t.into())).collect()
589}
590
591pub fn make_inputs_for_model(model: &dyn Model) -> TractResult<TVec<TValue>> {
592    make_inputs(
593        &model
594            .input_outlets()
595            .iter()
596            .map(|&t| model.outlet_typedfact(t))
597            .collect::<TractResult<Vec<TypedFact>>>()?,
598    )
599}
600
601#[allow(unused_variables)]
602pub fn tensor_for_fact(
603    fact: &TypedFact,
604    streaming_dim: Option<usize>,
605    tv: Option<&TensorValues>,
606) -> TractResult<Tensor> {
607    if let Some(value) = &fact.konst {
608        return Ok(value.clone().into_tensor());
609    }
610    Ok(random(
611        fact.shape
612            .as_concrete()
613            .with_context(|| format!("Expected concrete shape, found: {fact:?}"))?,
614        fact.datum_type,
615        tv,
616    ))
617}
618
619/// Generates a random tensor of a given size and type.
620pub fn random(sizes: &[usize], datum_type: DatumType, tv: Option<&TensorValues>) -> Tensor {
621    use rand::{RngExt, SeedableRng};
622    let mut rng = rand::rngs::StdRng::seed_from_u64(21242);
623    let mut tensor = Tensor::zero::<f32>(sizes).unwrap();
624    let mut tensor_plain = tensor.try_as_plain_mut().unwrap();
625    let slice = tensor_plain.as_slice_mut::<f32>().unwrap();
626    if let Some(range) = tv.and_then(|tv| tv.random_range.as_ref()) {
627        slice.iter_mut().for_each(|x| *x = rng.random_range(range.clone()))
628    } else {
629        slice.iter_mut().for_each(|x| *x = rng.random())
630    };
631    tensor.cast_to_dt(datum_type).unwrap().into_owned()
632}