// tract_onnx/ops/array/shape.rs
use std::ops::Range;

use crate::model::ParsingContext;
use crate::pb::NodeProto;
use tract_hir::internal::*;

7pub fn shape(
8 _ctx: &ParsingContext,
9 node: &NodeProto,
10) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
11 let start = node.get_attr_opt("start")?.unwrap_or(0);
12 let end = node.get_attr_opt("end")?;
13 Ok((expand(Shape { start, end }), vec![]))
14}
15
16#[derive(Debug, Clone, new, Default, Hash)]
17struct Shape {
18 start: i64,
19 end: Option<i64>,
20}
21
22
23
24impl Shape {
25 fn resolve(&self, rank: i64) -> Range<usize> {
26 let start =
27 if self.start >= 0 { self.start } else { (rank + self.start).clamp(0, rank) } as usize;
28 let end = if let Some(end) = self.end {
29 if end >= 0 {
30 end
31 } else {
32 end + rank
33 }
34 } else {
35 rank
36 }
37 .clamp(0, rank) as usize;
38 start..end
39 }
40}
41
42impl Expansion for Shape {
43 fn name(&self) -> StaticName {
44 "Shape".into()
45 }
46
47 fn rules<'r, 'p: 'r, 's: 'r>(
48 &'s self,
49 s: &mut Solver<'r>,
50 inputs: &'p [TensorProxy],
51 outputs: &'p [TensorProxy],
52 ) -> InferenceResult {
53 check_input_arity(inputs, 1)?;
54 check_output_arity(outputs, 1)?;
55 s.equals(&outputs[0].rank, 1)?;
56 s.equals(&outputs[0].datum_type, TDim::datum_type())?;
57 s.given(&inputs[0].shape, |s, shape| {
58 let rank = shape.len() as i64;
59 let range = self.resolve(rank);
60 s.equals(&outputs[0].value, rctensor1(&shape[range]))?;
61 Ok(())
62 })
63 }
64
65 fn wire(
66 &self,
67 prefix: &str,
68 model: &mut TypedModel,
69 inputs: &[OutletId],
70 ) -> TractResult<TVec<OutletId>> {
71 let fact = model.outlet_fact(inputs[0])?;
72 let range = self.resolve(fact.rank() as i64);
73 let shape = fact.shape.to_tvec();
74 let wire = model.add_const(prefix, tensor1(&shape[range]))?;
75 Ok(tvec!(wire))
76 }
77}