tract_onnx/ops/array/shape.rs

1use std::ops::Range;
2
3use crate::model::ParsingContext;
4use crate::pb::NodeProto;
5use tract_hir::internal::*;
6
7pub fn shape(
8    _ctx: &ParsingContext,
9    node: &NodeProto,
10) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
11    let start = node.get_attr_opt("start")?.unwrap_or(0);
12    let end = node.get_attr_opt("end")?;
13    Ok((expand(Shape { start, end }), vec![]))
14}
15
16#[derive(Debug, Clone, new, Default, Hash)]
17struct Shape {
18    start: i64,
19    end: Option<i64>,
20}
21
22
23
24impl Shape {
25    fn resolve(&self, rank: i64) -> Range<usize> {
26        let start =
27            if self.start >= 0 { self.start } else { (rank + self.start).clamp(0, rank) } as usize;
28        let end = if let Some(end) = self.end {
29            if end >= 0 {
30                end
31            } else {
32                end + rank
33            }
34        } else {
35            rank
36        }
37        .clamp(0, rank) as usize;
38        start..end
39    }
40}
41
42impl Expansion for Shape {
43    fn name(&self) -> StaticName {
44        "Shape".into()
45    }
46
47    fn rules<'r, 'p: 'r, 's: 'r>(
48        &'s self,
49        s: &mut Solver<'r>,
50        inputs: &'p [TensorProxy],
51        outputs: &'p [TensorProxy],
52    ) -> InferenceResult {
53        check_input_arity(inputs, 1)?;
54        check_output_arity(outputs, 1)?;
55        s.equals(&outputs[0].rank, 1)?;
56        s.equals(&outputs[0].datum_type, TDim::datum_type())?;
57        s.given(&inputs[0].shape, |s, shape| {
58            let rank = shape.len() as i64;
59            let range = self.resolve(rank);
60            s.equals(&outputs[0].value, rctensor1(&shape[range]))?;
61            Ok(())
62        })
63    }
64
65    fn wire(
66        &self,
67        prefix: &str,
68        model: &mut TypedModel,
69        inputs: &[OutletId],
70    ) -> TractResult<TVec<OutletId>> {
71        let fact = model.outlet_fact(inputs[0])?;
72        let range = self.resolve(fact.rank() as i64);
73        let shape = fact.shape.to_tvec();
74        let wire = model.add_const(prefix, tensor1(&shape[range]))?;
75        Ok(tvec!(wire))
76    }
77}