tract_core/ops/array/range.rs

1use crate::ops::cast::Cast;
2use tract_num_traits::AsPrimitive;
3use tract_num_traits::Zero;
4
5use crate::internal::*;
6
7use super::Slice;
8
9#[derive(Debug, Default, Clone, new, Hash)]
10pub struct Range {
11    len: TDim,
12}
13
14impl Op for Range {
15    fn name(&self) -> Cow<str> {
16        "Range".into()
17    }
18
19    op_as_typed_op!();
20}
21
22impl EvalOp for Range {
23    fn is_stateless(&self) -> bool {
24        true
25    }
26
27    fn eval_with_session(
28        &self,
29        session: &SessionState,
30        inputs: TVec<TValue>,
31    ) -> TractResult<TVec<TValue>> {
32        let (start, end, step) = args_3!(inputs);
33        Ok(tvec!(self.make(&start, &end, &step, &session.resolved_symbols)?.into_tvalue()))
34    }
35}
36
37impl Range {
38    fn make_t<T: Datum + for<'a> std::ops::Add<&'a T, Output = T>>(
39        start: &Tensor,
40        step: &Tensor,
41        len: usize,
42    ) -> TractResult<Tensor> {
43        unsafe {
44            let mut result = Tensor::uninitialized::<T>(&[len])?;
45            let mut v = start.to_scalar::<T>()?.clone();
46            let step = step.to_scalar::<T>()?;
47            for i in 0..len {
48                result.as_slice_mut_unchecked::<T>()[i] = v.clone();
49                v = v + step;
50            }
51            Ok(result)
52        }
53    }
54
55    fn make(
56        &self,
57        start: &Tensor,
58        end: &Tensor,
59        step: &Tensor,
60        values: &SymbolValues,
61    ) -> TractResult<Tensor> {
62        if start.datum_type() == TDim::datum_type() {
63            let start = start.to_scalar::<TDim>()?.eval(values).to_i64()?;
64            let step = step.to_scalar::<TDim>()?.eval(values).to_i64()?;
65            let len = {
66                let end = end.to_scalar::<TDim>()?.eval(values).to_i64()?;
67                #[allow(clippy::cast_abs_to_unsigned)]
68                ((end - start).abs() as usize).divceil(step.abs() as usize)
69            };
70            Self::make_t::<i64>(&tensor0(start), &tensor0(step), len)
71        } else {
72            let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())(
73                self, start, end, step
74            ))?;
75            dispatch_numbers!(Self::make_t(start.datum_type())(start, step, len))
76        }
77    }
78
79    fn len_for_numbers<T: Datum + AsPrimitive<f64>>(
80        &self,
81        start: &Tensor,
82        end: &Tensor,
83        step: &Tensor,
84    ) -> TractResult<usize> {
85        let start = start.to_scalar::<T>()?;
86        let end = end.to_scalar::<T>()?;
87        let step = step.to_scalar::<T>()?;
88        Ok(((end.as_() - start.as_()) / (step.as_())).ceil() as usize)
89    }
90}
91
92impl TypedOp for Range {
93    fn declutter(
94        &self,
95        model: &TypedModel,
96        node: &TypedNode,
97    ) -> TractResult<Option<TypedModelPatch>> {
98        let Some(succ) = model.single_succ(node.id)? else { return Ok(None) };
99        let Some(slice) = succ.op_as::<Slice>() else { return Ok(None) };
100        if slice.start.is_zero() && slice.end.is_one() {
101            let mut patch = TypedModelPatch::default();
102            let mut wire = patch.tap_model(model, node.inputs[0])?;
103            if model.outlet_fact(node.inputs[0])?.datum_type.is_tdim() {
104                wire = patch.wire_node(
105                    format!("{}.cast-tdim", node.name),
106                    Cast { to: DatumType::I64 },
107                    &[wire],
108                )?[0];
109            }
110            let wire = patch.wire_node(&node.name, AxisOp::Add(0), &[wire])?;
111            patch.shunt_outside(model, succ.id.into(), wire[0])?;
112            return Ok(Some(patch));
113        }
114        Ok(None)
115    }
116
117    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
118        let [start, end, step] = inputs else {
119            bail!("Expects three inputs");
120        };
121        ensure!(start.datum_type() == end.datum_type());
122        ensure!(start.datum_type() == step.datum_type());
123        ensure!(start.shape.volume().is_one());
124        ensure!(end.shape.volume().is_one());
125        ensure!(step.shape.volume().is_one());
126        if let (Some(start), Some(end), Some(step)) = (&start.konst, &end.konst, &step.konst) {
127            if start.datum_type() == TDim::datum_type() {
128                let start = start.to_scalar::<TDim>()?;
129                let end = end.to_scalar::<TDim>()?;
130                let step = step.cast_to_scalar::<i64>()?;
131                let len = if step < 0 {
132                    (start.clone() - end).divceil(-step as usize)
133                } else {
134                    (end.clone() - start).divceil(step as usize)
135                };
136                Ok(tvec!(DatumType::I64.fact([len])))
137            } else {
138                let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())(
139                    self, start, end, step
140                ))?
141                .to_dim();
142                Ok(tvec!(start.datum_type().fact([len])))
143            }
144        } else {
145            Ok(tvec!(start.datum_type.fact(&[self.len.clone()])))
146        }
147    }
148
149    as_op!();
150}