tract_core/ops/array/
range.rs1use crate::ops::cast::Cast;
2use tract_num_traits::AsPrimitive;
3use tract_num_traits::Zero;
4
5use crate::internal::*;
6
7use super::Slice;
8
9#[derive(Debug, Default, Clone, new, Hash)]
10pub struct Range {
11 len: TDim,
12}
13
14impl Op for Range {
15 fn name(&self) -> StaticName {
16 "Range".into()
17 }
18
19 op_as_typed_op!();
20}
21
22impl EvalOp for Range {
23 fn is_stateless(&self) -> bool {
24 true
25 }
26
27 fn eval_with_session(
28 &self,
29 _node_id: usize,
30 session: &TurnState,
31 inputs: TVec<TValue>,
32 ) -> TractResult<TVec<TValue>> {
33 let (start, end, step) = args_3!(inputs);
34 Ok(tvec!(self.make(&start, &end, &step, &session.resolved_symbols)?.into_tvalue()))
35 }
36}
37
38impl Range {
39 fn make_t<T: Datum + for<'a> std::ops::Add<&'a T, Output = T>>(
40 start: &Tensor,
41 step: &Tensor,
42 len: usize,
43 ) -> TractResult<Tensor> {
44 unsafe {
45 let mut result = Tensor::uninitialized::<T>(&[len])?;
46 let mut v = start.try_as_dense()?.to_scalar::<T>()?.clone();
47 let step = step.try_as_dense()?.to_scalar::<T>()?;
48 {
49 let mut result_dense = result.try_as_dense_mut()?;
50 for i in 0..len {
51 result_dense.as_slice_mut_unchecked::<T>()[i] = v.clone();
52 v = v + step;
53 }
54 }
55 Ok(result)
56 }
57 }
58
59 fn make(
60 &self,
61 start: &Tensor,
62 end: &Tensor,
63 step: &Tensor,
64 values: &SymbolValues,
65 ) -> TractResult<Tensor> {
66 if start.datum_type() == TDim::datum_type() {
67 let start = start.try_as_dense()?.to_scalar::<TDim>()?.eval(values).to_i64()?;
68 let step = step.try_as_dense()?.to_scalar::<TDim>()?.eval(values).to_i64()?;
69 let len = {
70 let end = end.try_as_dense()?.to_scalar::<TDim>()?.eval(values).to_i64()?;
71 #[allow(clippy::cast_abs_to_unsigned)]
72 ((end - start).abs() as usize).divceil(step.abs() as usize)
73 };
74 Self::make_t::<i64>(&tensor0(start), &tensor0(step), len)
75 } else {
76 let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())(
77 self, start, end, step
78 ))?;
79 dispatch_numbers!(Self::make_t(start.datum_type())(start, step, len))
80 }
81 }
82
83 fn len_for_numbers<T: Datum + AsPrimitive<f64>>(
84 &self,
85 start: &Tensor,
86 end: &Tensor,
87 step: &Tensor,
88 ) -> TractResult<usize> {
89 let start = start.try_as_dense()?.to_scalar::<T>()?;
90 let end = end.try_as_dense()?.to_scalar::<T>()?;
91 let step = step.try_as_dense()?.to_scalar::<T>()?;
92 Ok(((end.as_() - start.as_()) / (step.as_())).ceil() as usize)
93 }
94}
95
96impl TypedOp for Range {
97 fn declutter(
98 &self,
99 model: &TypedModel,
100 node: &TypedNode,
101 ) -> TractResult<Option<TypedModelPatch>> {
102 rule_if_some!(succ = model.single_succ(node.id)?);
103 rule_if_some!(slice = succ.op_as::<Slice>());
104 rule_if!(slice.start.is_zero());
105 rule_if!(slice.end.is_zero());
106
107 let mut patch = TypedModelPatch::default();
108 let mut wire = patch.tap_model(model, node.inputs[0])?;
109 if model.outlet_fact(node.inputs[0])?.datum_type.is_tdim() {
110 wire = patch.wire_node(
111 format!("{}.cast-tdim", node.name),
112 Cast { to: DatumType::I64 },
113 &[wire],
114 )?[0];
115 }
116 let wire = patch.wire_node(&node.name, AxisOp::Add(0), &[wire])?;
117 patch.shunt_outside(model, succ.id.into(), wire[0])?;
118 Ok(Some(patch))
119 }
120
121 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
122 let [start, end, step] = inputs else {
123 bail!("Expects three inputs");
124 };
125 ensure!(start.datum_type() == end.datum_type());
126 ensure!(start.datum_type() == step.datum_type());
127 ensure!(start.shape.volume().is_one());
128 ensure!(end.shape.volume().is_one());
129 ensure!(step.shape.volume().is_one());
130 if let (Some(start), Some(end), Some(step)) = (&start.konst, &end.konst, &step.konst) {
131 if start.datum_type() == TDim::datum_type() {
132 let start = start.try_as_dense()?.to_scalar::<TDim>()?;
133 let end = end.try_as_dense()?.to_scalar::<TDim>()?;
134 let step = step.cast_to_scalar::<i64>()?;
135 let len = if step < 0 {
136 (start.clone() - end).divceil(-step as usize)
137 } else {
138 (end.clone() - start).divceil(step as usize)
139 };
140 Ok(tvec!(DatumType::I64.fact([len])))
141 } else {
142 let len = dispatch_numbers!(Self::len_for_numbers(start.datum_type())(
143 self, start, end, step
144 ))?
145 .to_dim();
146 Ok(tvec!(start.datum_type().fact([len])))
147 }
148 } else {
149 Ok(tvec!(start.datum_type.fact(std::slice::from_ref(&self.len))))
150 }
151 }
152
153 as_op!();
154}